1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Operation.h"
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 using namespace mlir::tensor;
22 
23 namespace mlir {
24 namespace tensor {
25 namespace {
26 
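/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Illustrative sketch (types shown are examples; the actual layout and
/// memory space are taken from the source buffer):
///
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
///
/// bufferizes roughly to (%m being the buffer of %0):
///
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>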
27 struct CastOpInterface
28     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
29                                                     tensor::CastOp> {
30   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
31                               const AnalysisState &state) const {
32     return false;
33   }
34 
35   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
36                                const AnalysisState &state) const {
37     return false;
38   }
39 
40   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
41                                             const AnalysisState &state) const {
42     return {op->getResult(0)};
43   }
44 
45   BufferRelation bufferRelation(Operation *op, OpResult opResult,
46                                 const AnalysisState &state) const {
47     return BufferRelation::Equivalent;
48   }
49 
50   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51                           const BufferizationOptions &options) const {
52     auto castOp = cast<tensor::CastOp>(op);
53 
54     // The result buffer still has the old (pre-cast) type.
55     FailureOr<Value> resultBuffer =
56         getBuffer(rewriter, castOp.getSource(), options);
57     if (failed(resultBuffer))
58       return failure();
59     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
60     TensorType resultTensorType =
61         castOp.getResult().getType().cast<TensorType>();
62     MemRefLayoutAttrInterface layout;
63 
64     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
65       if (resultTensorType.isa<RankedTensorType>())
66         layout = rankedMemRefType.getLayout();
67 
68     // Compute the new memref type.
69     Type resultMemRefType =
70         getMemRefType(resultTensorType, options, layout,
71                       sourceMemRefType.getMemorySpaceAsInt());
72 
73     // Replace the op with a memref.cast.
74     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
75                                              resultMemRefType) &&
76            "CallOp::bufferize: cast incompatible");
77     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
78                                                  *resultBuffer);
79 
80     return success();
81   }
82 };
83 
84 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
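///
/// Illustrative sketch (if the source layout does not permit collapsing, a
/// copy into an identity-layout buffer is inserted first):
///
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///          : tensor<2x3xf32> into tensor<6xf32>
///
/// bufferizes roughly to (%m being the buffer of %0):
///
///   %1 = memref.collapse_shape %m [[0, 1]]
///          : memref<2x3xf32> into memref<6xf32>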
85 struct CollapseShapeOpInterface
86     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
87                                                     tensor::CollapseShapeOp> {
88   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
89                               const AnalysisState &state) const {
90     return false;
91   }
92 
93   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
94                                const AnalysisState &state) const {
95     return false;
96   }
97 
98   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
99                                             const AnalysisState &state) const {
100     if (&opOperand == &op->getOpOperand(0) /*src*/)
101       return {op->getOpResult(0)};
102     return {};
103   }
104 
105   BufferRelation bufferRelation(Operation *op, OpResult opResult,
106                                 const AnalysisState &state) const {
107     return BufferRelation::Equivalent;
108   }
109 
110   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
111                           const BufferizationOptions &options) const {
112     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
113     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
114     FailureOr<Value> maybeBuffer =
115         getBuffer(rewriter, collapseShapeOp.getSrc(), options);
116     if (failed(maybeBuffer))
117       return failure();
118     Value buffer = *maybeBuffer;
119     auto bufferType = buffer.getType().cast<MemRefType>();
120 
121     if (tensorResultType.getRank() == 0) {
122       // 0-d collapses must go through a different op builder.
123       MemRefType resultType;
124 
125       if (bufferType.getLayout().isIdentity()) {
126         // Standard layout: result type has no offset.
127         MemRefLayoutAttrInterface layout;
128         resultType = MemRefType::get({}, tensorResultType.getElementType(),
129                                      layout, bufferType.getMemorySpace());
130       } else {
131         // Source memref has a layout map: result type has the same offset as
132         // the source type.
133         SmallVector<int64_t> strides;
134         int64_t offset;
135         if (failed(getStridesAndOffset(bufferType, strides, offset)))
136           return failure();
137         AffineMap resultLayout =
138             makeStridedLinearLayoutMap({}, offset, op->getContext());
139         resultType =
140             MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
141                             bufferType.getMemorySpaceAsInt());
142       }
143 
144       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
145           rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
146       return success();
147     }
148 
149     // If the dims are not collapsible (due to an incompatible source layout
150     // map), force an out-of-place bufferization, i.e., a buffer copy. This
151     // newly allocated buffer will have no layout map and thus be collapsible.
152     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
153         bufferType, collapseShapeOp.getReassociationIndices());
154     if (!canBeCollapsed) {
155       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
156       AnalysisState analysisState(options);
157       FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
158           rewriter, op->getLoc(), collapseShapeOp.getSrc(),
159           analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
160       if (failed(tensorAlloc))
161         return failure();
162       auto memrefType =
163           MemRefType::get(collapseShapeOp.getSrcType().getShape(),
164                           collapseShapeOp.getSrcType().getElementType(),
165                           AffineMap(), bufferType.getMemorySpaceAsInt());
166       buffer = rewriter.create<bufferization::ToMemrefOp>(
167           op->getLoc(), memrefType, *tensorAlloc);
168     }
169 
170     // Result type is inferred by the builder.
171     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
172         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
173     return success();
174   }
175 };
176 
177 /// Bufferization of tensor.dim. Replace with memref.dim.
178 struct DimOpInterface
179     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
180                                                     tensor::DimOp> {
181   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
182                               const AnalysisState &state) const {
183     return true;
184   }
185 
186   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
187                                const AnalysisState &state) const {
188     return false;
189   }
190 
191   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
192                                             const AnalysisState &state) const {
193     return {};
194   }
195 
196   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
197                           const BufferizationOptions &options) const {
198     auto dimOp = cast<tensor::DimOp>(op);
199     FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
200     if (failed(v))
201       return failure();
202     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
204     return success();
205   }
206 };
207 
208 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
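///
/// Illustrative sketch (the memref result type is inferred from the
/// reassociation indices and the source buffer type):
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///          : tensor<6xf32> into tensor<2x3xf32>
///
/// bufferizes roughly to (%m being the buffer of %0):
///
///   %1 = memref.expand_shape %m [[0, 1]]
///          : memref<6xf32> into memref<2x3xf32>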
209 struct ExpandShapeOpInterface
210     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
211                                                     tensor::ExpandShapeOp> {
212   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
213                               const AnalysisState &state) const {
214     return false;
215   }
216 
217   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
218                                const AnalysisState &state) const {
219     return false;
220   }
221 
222   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
223                                             const AnalysisState &state) const {
224     if (&opOperand == &op->getOpOperand(0) /*src*/)
225       return {op->getOpResult(0)};
226     return {};
227   }
228 
229   BufferRelation bufferRelation(Operation *op, OpResult opResult,
230                                 const AnalysisState &state) const {
231     return BufferRelation::Equivalent;
232   }
233 
234   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
235                           const BufferizationOptions &options) const {
236     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
237     auto tensorResultType = expandShapeOp.getResultType();
238     FailureOr<Value> buffer =
239         getBuffer(rewriter, expandShapeOp.getSrc(), options);
240     if (failed(buffer))
241       return failure();
242 
243     // Memref result type is inferred by the builder based on reassociation
244     // indices and result shape.
245     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
246         rewriter, op, tensorResultType.getShape(), *buffer,
247         expandShapeOp.getReassociationIndices());
248     return success();
249   }
250 };
251 
252 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
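///
/// Illustrative sketch (#map stands for the strided layout inferred by
/// memref.subview):
///
///   %1 = tensor.extract_slice %t[%o][%sz][1]
///          : tensor<?xf32> to tensor<?xf32>
///
/// bufferizes roughly to (%m being the buffer of %t):
///
///   %1 = memref.subview %m[%o][%sz][1]
///          : memref<?xf32> to memref<?xf32, #map>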
253 struct ExtractSliceOpInterface
254     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
255                                                     tensor::ExtractSliceOp> {
256   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
257                               const AnalysisState &state) const {
258     return false;
259   }
260 
261   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
262                                const AnalysisState &state) const {
263     return false;
264   }
265 
266   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
267                                             const AnalysisState &state) const {
268     if (&opOperand == &op->getOpOperand(0) /*source*/)
269       return {op->getOpResult(0)};
270     return {};
271   }
272 
273   BufferRelation bufferRelation(Operation *op, OpResult opResult,
274                                 const AnalysisState &state) const {
275     return BufferRelation::None;
276   }
277 
278   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
279                           const BufferizationOptions &options) const {
280     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
281     Location loc = extractSliceOp.getLoc();
282 
283     // Even if this op was decided to bufferize out-of-place, do not insert the
284     // buffer copy yet. This is done later in this function.
285     FailureOr<Value> srcMemref =
286         getBuffer(rewriter, extractSliceOp.getSource(), options);
287     if (failed(srcMemref))
288       return failure();
289     auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
290     auto dstTensorType =
291         extractSliceOp.getResult().getType().cast<RankedTensorType>();
292 
293     // Expand offsets, sizes and strides to the full rank to handle the
294     // rank-reducing case.
295     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
296     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
297     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
298     OffsetSizeAndStrideOpInterface::expandToRank(
299         *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
300         [&](Value target, int64_t dim) -> OpFoldResult {
301           auto shapedType = target.getType().cast<ShapedType>();
302           if (shapedType.isDynamicDim(dim))
303             return rewriter.create<memref::DimOp>(loc, target, dim).result();
304           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
305         });
306     // Bufferize to subview.
307     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
308                                  dstTensorType.getRank(), srcMemrefType,
309                                  mixedOffsets, mixedSizes, mixedStrides)
310                                  .cast<MemRefType>();
311     Value subView = rewriter.create<memref::SubViewOp>(
312         loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
313         mixedStrides);
314 
315     replaceOpWithBufferizedValues(rewriter, op, subView);
316     return success();
317   }
318 };
319 
320 /// Bufferization of tensor.extract. Replace with memref.load.
321 struct ExtractOpInterface
322     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
323                                                     tensor::ExtractOp> {
324   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
325                               const AnalysisState &state) const {
326     return true;
327   }
328 
329   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
330                                const AnalysisState &state) const {
331     return false;
332   }
333 
334   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
335                                             const AnalysisState &state) const {
336     return {};
337   }
338 
339   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
340                           const BufferizationOptions &options) const {
341     auto extractOp = cast<tensor::ExtractOp>(op);
342     FailureOr<Value> srcMemref =
343         getBuffer(rewriter, extractOp.getTensor(), options);
344     if (failed(srcMemref))
345       return failure();
346     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
348     return success();
349   }
350 };
351 
// Implements backtracking to traverse indices of the output buffer while
// iterating over op.getElements().
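//
// For example (illustrative), for shape [2, 3] the stores are emitted in
// row-major index order: (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), consuming
// one value from `elementIt` per store.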
354 static void createStores(RewriterBase &rewriter, Location loc, int dim,
355                          Value buffer, ArrayRef<int64_t> shape,
356                          ArrayRef<Value> constants,
357                          OperandRange::iterator &elementIt,
358                          SmallVectorImpl<Value> &indices) {
359   if (dim == static_cast<int>(shape.size()) - 1) {
360     for (int i = 0; i < shape.back(); ++i) {
361       indices.back() = constants[i];
362       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
363       ++elementIt;
364     }
365     return;
366   }
367   for (int i = 0; i < shape[dim]; ++i) {
368     indices[dim] = constants[i];
369     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
370                  indices);
371   }
372 }
373 
374 /// Bufferization of tensor.from_elements.
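///
/// Illustrative sketch (the allocation goes through
/// allocateTensorForShapedValue / bufferization.to_memref):
///
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
///
/// bufferizes roughly to:
///
///   %m = ... newly allocated memref<2xf32> ...
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>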
375 struct FromElementsOpInterface
376     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
377                                                     tensor::FromElementsOp> {
378   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
379                           const BufferizationOptions &options) const {
380     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
381 
382     // TODO: Implement memory space for this op.
383     if (options.defaultMemorySpace != static_cast<unsigned>(0))
384       return op->emitError("memory space not implemented yet");
385 
386     // Allocate a buffer for the result.
387     Location loc = op->getLoc();
388     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
389     auto shape = tensorType.getShape();
390     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
391     AnalysisState analysisState(options);
392     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
393         rewriter, loc, fromElementsOp.getResult(),
394         analysisState.isTensorYielded(fromElementsOp.getResult()), options,
395         /*copy=*/false);
396     if (failed(tensorAlloc))
397       return failure();
398     auto memrefType =
399         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
400     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
401         op->getLoc(), memrefType, *tensorAlloc);
402 
403     // Case: tensor<0xelem_type>.
404     if (fromElementsOp.getElements().empty()) {
405       replaceOpWithBufferizedValues(rewriter, op, buffer);
406       return success();
407     }
408 
409     // Case: tensor<elem_type>.
410     if (shape.empty()) {
411       rewriter.create<memref::StoreOp>(
412           loc, fromElementsOp.getElements().front(), buffer);
413       replaceOpWithBufferizedValues(rewriter, op, buffer);
414       return success();
415     }
416 
417     // Create constants for the range of possible indices [0, max{shape_i}).
418     auto maxDim = *std::max_element(shape.begin(), shape.end());
419     SmallVector<Value, 2> constants;
420     constants.reserve(maxDim);
421     for (int i = 0; i < maxDim; ++i)
422       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
423 
424     // Traverse all `elements` and create `memref.store` ops.
425     auto elementIt = fromElementsOp.getElements().begin();
426     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
427     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
428                  indices);
429 
430     replaceOpWithBufferizedValues(rewriter, op, buffer);
431     return success();
432   }
433 };
434 
435 /// Bufferization of tensor.generate.
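///
/// Illustrative sketch: a buffer is allocated for the result and the body of
/// the tensor.generate op is moved into an scf.parallel loop that stores one
/// element per iteration, roughly:
///
///   scf.parallel (%i) = (%c0) to (%ub) step (%c1) {
///     ... inlined body ...
///     memref.store %elem, %m[%i] : memref<?xf32>
///   }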
436 struct GenerateOpInterface
437     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
438                                                     tensor::GenerateOp> {
439   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
440                           const BufferizationOptions &options) const {
441     auto generateOp = cast<tensor::GenerateOp>(op);
442 
443     // TODO: Implement memory space for this op.
444     if (options.defaultMemorySpace != static_cast<unsigned>(0))
445       return op->emitError("memory space not implemented yet");
446 
447     auto tensorType = generateOp.getType().cast<RankedTensorType>();
448     // Allocate memory.
449     Location loc = op->getLoc();
450     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
451     AnalysisState analysisState(options);
452     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
453         rewriter, loc, generateOp.getResult(),
454         analysisState.isTensorYielded(generateOp.getResult()), options,
455         /*copy=*/false);
456     if (failed(tensorAlloc))
457       return failure();
458     auto memrefType =
459         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
460     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
461         op->getLoc(), memrefType, *tensorAlloc);
462 
463     // Collect loop bounds.
464     int64_t rank = memrefType.getRank();
465     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
466     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
467     SmallVector<Value, 4> lowerBounds(rank, zero);
468     SmallVector<Value, 4> steps(rank, one);
469     SmallVector<Value, 4> upperBounds;
470     int nextDynamicIndex = 0;
471     for (int i = 0; i < rank; i++) {
472       Value upperBound =
473           memrefType.isDynamicDim(i)
474               ? generateOp.getDynamicExtents()[nextDynamicIndex++]
475               : rewriter.create<arith::ConstantIndexOp>(
476                     loc, memrefType.getDimSize(i));
477       upperBounds.push_back(upperBound);
478     }
479 
480     // Generate tensor elements with a parallel loop that stores into
481     // each element of the resulting memref. We use mergeBlockBefore to "move"
482     // this op's body into the scf.parallel's body.
483     auto parallel =
484         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
485     Block *parallelBody = parallel.getBody();
486     rewriter.mergeBlockBefore(&generateOp.getBody().front(),
487                               parallelBody->getTerminator(),
488                               parallelBody->getArguments());
489     // Replace the inlined yield op with a store op. The scf.parallel's builder
490     // already populated an scf.yield at the end, so we don't need to worry
491     // about creating that.
492     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
493     rewriter.setInsertionPointAfter(elementYield);
494     rewriter.replaceOpWithNewOp<memref::StoreOp>(
495         elementYield, elementYield->getOperands()[0], buffer,
496         parallelBody->getArguments());
497 
498     replaceOpWithBufferizedValues(rewriter, op, buffer);
499     return success();
500   }
501 };
502 
503 /// Bufferization of tensor.insert. Replace with memref.store.
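///
/// Illustrative sketch:
///
///   %1 = tensor.insert %f into %t[%i] : tensor<8xf32>
///
/// bufferizes roughly to a store into the buffer of %t (here %m), with %1
/// replaced by that same buffer:
///
///   memref.store %f, %m[%i] : memref<8xf32>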
504 struct InsertOpInterface
505     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
506                                                     tensor::InsertOp> {
507   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
508                               const AnalysisState &state) const {
509     return true;
510   }
511 
512   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
513                                const AnalysisState &state) const {
514     return true;
515   }
516 
517   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
518                                             const AnalysisState &state) const {
519     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
520            "expected dest OpOperand");
521     return {op->getOpResult(0)};
522   }
523 
524   SmallVector<OpOperand *>
525   getAliasingOpOperand(Operation *op, OpResult opResult,
526                        const AnalysisState &state) const {
527     return {&op->getOpOperand(1) /*dest*/};
528   }
529 
530   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
531                           const BufferizationOptions &options) const {
532     auto insertOp = cast<tensor::InsertOp>(op);
533     FailureOr<Value> destMemref =
534         getBuffer(rewriter, insertOp.getDest(), options);
535     if (failed(destMemref))
536       return failure();
537     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
538                                      *destMemref, insertOp.getIndices());
539     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
540     return success();
541   }
542 
543   BufferRelation bufferRelation(Operation *op, OpResult opResult,
544                                 const AnalysisState &state) const {
545     return BufferRelation::Equivalent;
546   }
547 };
548 
549 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
550 /// equivalent operand / result and same offset/sizes/strides specification).
551 ///
552 /// This is one particular type of relationship between ops on tensors that
553 /// reduce to an equivalence on buffers. This should be generalized and
554 /// exposed as interfaces on the proper types.
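///
/// For example (illustrative), the pair
///
///   %0 = tensor.extract_slice %t[%a][%c][1] ...
///   %2 = tensor.insert_slice %1 into %t[%a][%c][1] ...
///
/// matches if %t bufferizes to equivalent buffers in both ops and the
/// offsets/sizes/strides are identical.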
555 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
556                                          ExtractSliceOp st, InsertSliceOp sti) {
557   if (!st || !sti)
558     return false;
  if (st != sti &&
560       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
561     return false;
562   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
563     return false;
564   return true;
565 }
566 
/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
569 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
570                                       InsertSliceOp insertOp) {
571   auto condition = [&](Value val) {
572     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
573       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
574         return true;
575     return false;
576   };
577 
578   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
579                       condition);
580 }
581 
582 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
583 /// certain circumstances, this op can also be a no-op.
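///
/// Illustrative sketch (#map stands for the strided layout inferred by
/// memref.subview):
///
///   %1 = tensor.insert_slice %s into %t[%o][%sz][1]
///          : tensor<?xf32> into tensor<?xf32>
///
/// bufferizes roughly to a subview of the destination buffer plus a copy
/// (%src and %dst being the buffers of %s and %t):
///
///   %sv = memref.subview %dst[%o][%sz][1]
///           : memref<?xf32> to memref<?xf32, #map>
///   memref.copy %src, %sv : memref<?xf32> to memref<?xf32, #map>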
584 struct InsertSliceOpInterface
585     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
586                                                     tensor::InsertSliceOp> {
587   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
588                               const AnalysisState &state) const {
589     return true;
590   }
591 
592   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
593                                const AnalysisState &state) const {
594     return &opOperand == &op->getOpOperand(1) /*dest*/;
595   }
596 
597   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
598                                             const AnalysisState &state) const {
599     if (&opOperand == &op->getOpOperand(1) /*dest*/)
600       return {op->getResult(0)};
601     return {};
602   }
603 
604   BufferRelation bufferRelation(Operation *op, OpResult opResult,
605                                 const AnalysisState &state) const {
606     return BufferRelation::Equivalent;
607   }
608 
609   bool isNotConflicting(Operation *op, OpOperand *uRead,
610                         OpOperand *uConflictingWrite,
611                         const AnalysisState &state) const {
612     Operation *readingOp = uRead->getOwner();
613     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
614 
615     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
616     // uRead is an InsertSliceOp...
617     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
618       // As an example, consider the following IR.
619       //
620       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
621       // %1 = linalg.fill %cst, %0 {inplace= [true] }
622       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
623       //     {inplace= [true] }
624 
625       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
626       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
627           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
628                                     insertSliceOp))
629         // Case 1: The main insight is that InsertSliceOp reads only part of
630         // the destination tensor. The overwritten area is not read. If
631         // uConflictingWrite writes into exactly the memory location that is
632         // being read by uRead, this is not a conflict.
633         //
634         // In the above example:
635         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
636         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
637         //
638         // The read of %t does not conflict with the write of the FillOp
639         // (same aliases!) because the area that the FillOp operates on is
640         // exactly the one that is *not* read via %t.
641         return true;
642 
643       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
644           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
645           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
646         // Case 2: The read of the source tensor and the write to the dest
647         // tensor via an InsertSliceOp is not a conflict if the read is
648         // reading exactly that part of an equivalent tensor that the
649         // InsertSliceOp is writing.
650         //
651         // In the above example:
652         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
653         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
654         return true;
655     }
656 
657     // If uConflictingWrite is an InsertSliceOp...
658     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
659       // As an example, consider the following IR.
660       //
661       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
662       // %1 = linalg.fill %cst, %0 {inplace= [true] }
663       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
664       //     {inplace= [true] }
665       // %3 = vector.transfer_read %1, %cst
666       //
667       // In the above example:
668       // uRead             = OpOperand 0 (%1) of vector.transfer_read
669       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
670       // lastWrite         = %1
671       //
672       // This is not a conflict because the InsertSliceOp overwrites the
673       // memory segment of %1 with the exact same data. (Effectively, there
674       // is no memory write here.)
675       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
676           state.areEquivalentBufferizedValues(uRead->get(),
677                                               insertSliceOp.getSource()) &&
678           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
679                                     insertSliceOp))
680         return true;
681 
682     return false;
683   }
684 
685   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
686                           const BufferizationOptions &options) const {
687     // insert_slice ops arise from tiling and bufferizing them out-of-place is
688     // generally a deal breaker. When used with loops, this ends up cloning the
689     // whole tensor on every single iteration and is a symptom of a
690     // catastrophically bad scheduling decision.
691     // TODO: be very loud about it or even consider failing the pass.
692     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
693     Location loc = insertSliceOp.getLoc();
694     FailureOr<Value> dstMemref =
695         getBuffer(rewriter, insertSliceOp.getDest(), options);
696     if (failed(dstMemref))
697       return failure();
698 
699     // Expand offsets, sizes and strides to the full rank to handle the
700     // rank-reducing case.
701     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
702     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
703     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
704     OffsetSizeAndStrideOpInterface::expandToRank(
705         *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
706         [&](Value target, int64_t dim) -> OpFoldResult {
707           auto shapedType = target.getType().cast<ShapedType>();
708           if (shapedType.isDynamicDim(dim))
709             return rewriter.create<memref::DimOp>(loc, target, dim).result();
710           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
711         });
712     // Take a subview of the dst.
713     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
714     auto subviewMemRefType =
715         memref::SubViewOp::inferRankReducedResultType(
716             insertSliceOp.getSourceType().getRank(), dstMemrefType,
717             mixedOffsets, mixedSizes, mixedStrides)
718             .cast<MemRefType>();
719     Value subView = rewriter.create<memref::SubViewOp>(
720         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
721         mixedStrides);
722 
723     // Copy tensor. If this tensor.insert_slice has a matching
724     // tensor.extract_slice, the copy operation will eventually fold away.
725     FailureOr<Value> srcMemref =
726         getBuffer(rewriter, insertSliceOp.getSource(), options);
727     if (failed(srcMemref))
728       return failure();
729     if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
730       return failure();
731 
732     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
733     return success();
734   }
735 };
736 
737 /// Bufferization of tensor.rank. Replace with memref.rank.
738 struct RankOpInterface
739     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
740                                                     tensor::RankOp> {
741   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
742                               const AnalysisState &state) const {
743     return true;
744   }
745 
746   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
747                                const AnalysisState &state) const {
748     return false;
749   }
750 
751   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
752                                             const AnalysisState &state) const {
753     return {};
754   }
755 
756   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
757                           const BufferizationOptions &options) const {
758     auto rankOp = cast<tensor::RankOp>(op);
759     FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
760     if (failed(v))
761       return failure();
762     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
763                                                  *v);
764     return success();
765   }
766 };
767 
768 /// Bufferization of tensor.reshape. Replace with memref.reshape.
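///
/// Illustrative sketch (the shape operand is bufferized as well):
///
///   %1 = tensor.reshape %t(%shape)
///          : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
///
/// bufferizes roughly to (%m and %shape_m being the respective buffers):
///
///   %1 = memref.reshape %m(%shape_m)
///          : (memref<?xf32>, memref<1xindex>) -> memref<?xf32>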
769 struct ReshapeOpInterface
770     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
771                                                     tensor::ReshapeOp> {
772   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
773                               const AnalysisState &state) const {
774     if (&opOperand == &op->getOpOperand(1) /* shape */)
775       return true;
776     return false;
777   }
778 
779   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
780                                const AnalysisState &state) const {
781     return false;
782   }
783 
784   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
785                                             const AnalysisState &state) const {
786     return {op->getOpResult(0)};
787   }
788 
789   BufferRelation bufferRelation(Operation *op, OpResult opResult,
790                                 const AnalysisState &state) const {
791     return BufferRelation::Equivalent;
792   }
793 
794   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
795                           const BufferizationOptions &options) const {
796     auto reshapeOp = cast<tensor::ReshapeOp>(op);
797     FailureOr<Value> srcBuffer =
798         getBuffer(rewriter, reshapeOp.getSource(), options);
799     FailureOr<Value> shapeBuffer =
800         getBuffer(rewriter, reshapeOp.getShape(), options);
801     if (failed(srcBuffer) || failed(shapeBuffer))
802       return failure();
803     auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
804     auto resultMemRefType = getMemRefType(
805         resultTensorType, options, /*layout=*/{},
806         srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
807     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
808         rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
809     return success();
810   }
811 };
812 
813 /// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
814 /// equivalent operand / result and same offset/sizes/strides specification).
815 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
816                                          ExtractSliceOp st,
817                                          ParallelInsertSliceOp sti) {
818   if (!st || !sti)
819     return false;
820   if (st != sti &&
821       !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
822     return false;
823   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
824     return false;
825   return true;
826 }
827 
/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
830 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
831                                       ParallelInsertSliceOp insertOp) {
832   auto condition = [&](Value val) {
833     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
834       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
835         return true;
836     return false;
837   };
838 
839   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
840                       condition);
841 }
842 
/// Analysis and bufferization of ParallelInsertSliceOp.
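///
/// Illustrative sketch of bufferization: the op becomes a subview of the
/// destination buffer plus a copy, created outside of the parallel combining
/// op, roughly:
///
///   %sv = memref.subview %dst[...][...][...]
///   memref.copy %src, %sv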
844 struct ParallelInsertSliceOpInterface
845     : public BufferizableOpInterface::ExternalModel<
846           ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
847   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
848                                             const AnalysisState &state) const {
849     if (&opOperand != &op->getOpOperand(1) /*dest*/)
850       return {};
851 
    // ParallelInsertSliceOp itself has no results; return the tied result of
    // the surrounding iterating op.
853     auto insertOp = cast<ParallelInsertSliceOp>(op);
854     return {insertOp.getTiedOpResult()};
855   }
856 
857   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
858                               const AnalysisState &state) const {
859     return true;
860   }
861 
862   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
863                                const AnalysisState &state) const {
864     return &opOperand == &op->getOpOperand(1) /*dest*/;
865   }
866 
867   BufferRelation bufferRelation(Operation *op, OpResult opResult,
868                                 const AnalysisState &state) const {
869     return BufferRelation::Equivalent;
870   }
871 
872   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
873                                  const AnalysisState &state) const {
874     // This interface method is overridden because we want to set a custom
875     // insertion point for tensor copies. They should be inserted right before
876     // the ForeachThreadOp. E.g.:
877     //
    // %r0, %r1 = foreach_thread ... {
879     //   ...
880     //   perform_concurrently {
881     //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
882     //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
883     //   }
884     // }
885     //
886     // After TensorCopyInsertion:
887     //
888     // %copy = bufferization.alloc_tensor() copy(%d)
    // %r0, %r1 = foreach_thread ... {
890     //   ...
891     //   perform_concurrently {
892     //     parallel_insert_slice %a into %b ...
893     //     parallel_insert_slice %c into %copy ...
894     //   }
895     // }
896 
897     OpBuilder::InsertionGuard g(rewriter);
898     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
899     ParallelCombiningOpInterface parallelCombiningParent =
900         parallelInsertSliceOp.getParallelCombiningParent();
901     Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
902 
903     // Nothing to do if the destination tensor is inplace.
904     assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
905            "source is always in-place");
906     if (state.isInPlace(op->getOpOperand(1) /*dest*/))
907       return success();
908 
909     // Find corresponding OpResult.
910     OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
911 
912     // Insert tensor allocation right before the ForeachThreadOp.
913     rewriter.setInsertionPoint(parallelIteratingOp);
914     bool isYielded = state.isTensorYielded(opResult);
915     FailureOr<Value> alloc = allocateTensorForShapedValue(
916         rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
917         /*escape=*/isYielded, state.getOptions());
918     if (failed(alloc))
919       return failure();
920 
921     // Update destination operand.
922     rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
923       parallelInsertSliceOp.getDestMutable().assign(*alloc);
924     });
925 
926     return success();
927   }
928 
929   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
930                           const BufferizationOptions &options) const {
931     OpBuilder::InsertionGuard g(rewriter);
932     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
933     ParallelCombiningOpInterface parallelCombiningParent =
934         parallelInsertSliceOp.getParallelCombiningParent();
935     Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
936 
937     // Get destination buffer.
938     FailureOr<Value> destBuffer =
939         getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
940     if (failed(destBuffer))
941       return failure();
942 
943     // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
944     rewriter.setInsertionPoint(parallelCombiningParent);
945     FailureOr<Value> srcBuffer =
946         getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
947     if (failed(srcBuffer))
948       return failure();
949     Value subview = rewriter.create<memref::SubViewOp>(
950         parallelInsertSliceOp.getLoc(), *destBuffer,
951         parallelInsertSliceOp.getMixedOffsets(),
952         parallelInsertSliceOp.getMixedSizes(),
953         parallelInsertSliceOp.getMixedStrides());
954     // This memcpy will fold away if everything bufferizes in-place.
955     if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
956                                     *srcBuffer, subview)))
957       return failure();
958 
959     // Replace all uses of parallelIteratingOp (just the corresponding result).
960     rewriter.setInsertionPointAfter(parallelIteratingOp);
961     Value toTensorOp =
962         rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
963     // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
964     SmallVector<OpOperand *> resultUses = llvm::to_vector(
965         llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
966                         [](OpOperand &use) { return &use; }));
967     for (OpOperand *use : resultUses) {
968       rewriter.updateRootInPlace(use->getOwner(),
969                                  [&]() { use->set(toTensorOp); });
970     }
971     rewriter.eraseOp(op);
972     return success();
973   }
974 
  // TODO: This is copied from the InsertSliceOp implementation above. Find a
  // way to share the code.
977   bool isNotConflicting(Operation *op, OpOperand *uRead,
978                         OpOperand *uConflictingWrite,
979                         const AnalysisState &state) const {
980     Operation *readingOp = uRead->getOwner();
981     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
982 
983     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
984     // uRead is an InsertSliceOp...
985     if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
986       // As an example, consider the following IR.
987       //
988       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
989       // %1 = linalg.fill %cst, %0 {inplace= [true] }
990       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
991       //     {inplace= [true] }
992 
993       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
994       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
995           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
996                                     insertSliceOp))
997         // Case 1: The main insight is that InsertSliceOp reads only part of
998         // the destination tensor. The overwritten area is not read. If
999         // uConflictingWrite writes into exactly the memory location that is
1000         // being read by uRead, this is not a conflict.
1001         //
1002         // In the above example:
1003         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
1004         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
1005         //
1006         // The read of %t does not conflict with the write of the FillOp
1007         // (same aliases!) because the area that the FillOp operates on is
1008         // exactly the one that is *not* read via %t.
1009         return true;
1010 
1011       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
1012           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1013           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
1014         // Case 2: The read of the source tensor and the write to the dest
1015         // tensor via an InsertSliceOp is not a conflict if the read is
1016         // reading exactly that part of an equivalent tensor that the
1017         // InsertSliceOp is writing.
1018         //
1019         // In the above example:
1020         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
1021         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1022         return true;
1023     }
1024 
1025     // If uConflictingWrite is an InsertSliceOp...
1026     if (auto insertSliceOp =
1027             dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
1028       // As an example, consider the following IR.
1029       //
1030       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1031       // %1 = linalg.fill %cst, %0 {inplace= [true] }
1032       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1033       //     {inplace= [true] }
1034       // %3 = vector.transfer_read %1, %cst
1035       //
1036       // In the above example:
1037       // uRead             = OpOperand 0 (%1) of vector.transfer_read
1038       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1039       // lastWrite         = %1
1040       //
1041       // This is not a conflict because the InsertSliceOp overwrites the
1042       // memory segment of %1 with the exact same data. (Effectively, there
1043       // is no memory write here.)
1044       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1045           state.areEquivalentBufferizedValues(uRead->get(),
1046                                               insertSliceOp.getSource()) &&
1047           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
1048                                     insertSliceOp))
1049         return true;
1050 
1051     return false;
1052   }
1053 };
1054 
1055 } // namespace
1056 } // namespace tensor
1057 } // namespace mlir
1058 
1059 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
1060     DialectRegistry &registry) {
1061   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
1062     CastOp::attachInterface<CastOpInterface>(*ctx);
1063     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1064     DimOp::attachInterface<DimOpInterface>(*ctx);
1065     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1066     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1067     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1068     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1069     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1070     InsertOp::attachInterface<InsertOpInterface>(*ctx);
1071     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1072     ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1073         *ctx);
1074     RankOp::attachInterface<RankOpInterface>(*ctx);
1075     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1076   });
1077 }
1078