//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

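/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Illustrative example (a sketch only; value names and the result layout are
/// not produced verbatim by this file):
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
/// roughly bufferizes to:
///   %1 = memref.cast %src_buf : memref<4xf32> to memref<?xf32>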
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    Value resultBuffer = getBuffer(rewriter, castOp.source(), options);
    auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(resultTensorType, options, layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
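///
/// Illustrative example (identity-layout source; a sketch, not the exact IR
/// emitted below):
///   %0 = tensor.collapse_shape %t [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// roughly bufferizes to:
///   %0 = memref.collapse_shape %t_buf [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>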
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    Value buffer = getBuffer(rewriter, collapseShapeOp.src(), options);
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      Value tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.src(),
          analysisState.isTensorYielded(collapseShapeOp.result()));
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    auto v = getBuffer(rewriter, dimOp.source(), options);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
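///
/// Illustrative example (a sketch, not the exact IR emitted below):
///   %0 = tensor.expand_shape %t [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
/// roughly bufferizes to memref.expand_shape on the source buffer.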
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    auto buffer = getBuffer(rewriter, expandShapeOp.src(), options);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
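///
/// Illustrative example (a sketch; the layout map of the result is only
/// indicated schematically):
///   %0 = tensor.extract_slice %t[4] [8] [1] : tensor<16xf32> to tensor<8xf32>
/// roughly bufferizes to:
///   %0 = memref.subview %t_buf[4] [8] [1]
///       : memref<16xf32> to memref<8xf32, #map>
/// where #map encodes the offset of 4 into the source buffer.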
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    auto srcMemref = getBuffer(rewriter, extractSliceOp.source(), options);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref = getBuffer(rewriter, extractOp.tensor(), options);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
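//
// For example (illustrative), for shape [2, 3] the stores are emitted in
// row-major order, i.e. for indices (0,0), (0,1), (0,2), (1,0), (1,1), (1,2),
// consuming one element of op.elements() per store.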
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
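///
/// Illustrative example (a sketch; the code below actually goes through an
/// allocated tensor and bufferization.to_memref):
///   %0 = tensor.from_elements %a, %b : tensor<2xf32>
/// roughly bufferizes to:
///   %buf = memref.alloc() : memref<2xf32>
///   memref.store %a, %buf[%c0] : memref<2xf32>
///   memref.store %b, %buf[%c1] : memref<2xf32>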
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.result(),
        analysisState.isTensorYielded(fromElementsOp.result()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
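///
/// Illustrative example (a sketch; the code below actually goes through an
/// allocated tensor and bufferization.to_memref):
///   %0 = tensor.generate %n {
///   ^bb0(%i: index):
///     ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
/// roughly bufferizes to an allocation plus an scf.parallel loop that stores
/// the yielded value at every index.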
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.result(),
        analysisState.isTensorYielded(generateOp.result()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    Value destMemref = getBuffer(rewriter, insertOp.dest(), options);
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
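///
/// Example of a matching pair (illustrative):
///   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   %1 = tensor.insert_slice %x into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>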
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
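///
/// Illustrative example (a sketch; the layout of the subview is only indicated
/// schematically):
///   %0 = tensor.insert_slice %s into %t[4] [8] [1]
///       : tensor<8xf32> into tensor<16xf32>
/// roughly bufferizes to:
///   %sv = memref.subview %t_buf[4] [8] [1]
///       : memref<16xf32> to memref<8xf32, #map>
///   memref.copy %s_buf, %sv : memref<8xf32> to memref<8xf32, #map>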
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace = [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace = [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    Value dstMemref = getBuffer(rewriter, insertSliceOp.dest(), options);

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    auto srcMemref = getBuffer(rewriter, insertSliceOp.source(), options);
    if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    auto v = getBuffer(rewriter, rankOp.tensor(), options);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
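///
/// Illustrative example (a sketch):
///   %0 = tensor.reshape %t(%shape)
///       : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// roughly bufferizes to memref.reshape on the bufferized source and shape
/// operands.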
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    Value srcBuffer = getBuffer(rewriter, reshapeOp.source(), options);
    Value shapeBuffer = getBuffer(rewriter, reshapeOp.shape(), options);
    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(resultTensorType, options);
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}