1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Operation.h"
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 using namespace mlir::tensor;
22 
23 namespace mlir {
24 namespace tensor {
25 namespace {
26 
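/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// A rough sketch of the rewrite; SSA names and types are illustrative only,
/// and the actual memref layouts depend on the bufferization options:
///
///   %0 = tensor.cast %t : tensor<?xf32> to tensor<4xf32>
///
/// bufferizes to approximately (with %m the buffer of %t):
///
///   %0 = memref.cast %m : memref<?xf32> to memref<4xf32>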
27 struct CastOpInterface
28     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
29                                                     tensor::CastOp> {
30   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
31                               const AnalysisState &state) const {
32     return false;
33   }
34 
35   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
36                                const AnalysisState &state) const {
37     return false;
38   }
39 
40   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
41                                             const AnalysisState &state) const {
42     return {op->getResult(0)};
43   }
44 
45   BufferRelation bufferRelation(Operation *op, OpResult opResult,
46                                 const AnalysisState &state) const {
47     return BufferRelation::Equivalent;
48   }
49 
50   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51                           const BufferizationOptions &options) const {
52     auto castOp = cast<tensor::CastOp>(op);
53 
54     // The result buffer still has the old (pre-cast) type.
55     FailureOr<Value> resultBuffer =
56         getBuffer(rewriter, castOp.getSource(), options);
57     if (failed(resultBuffer))
58       return failure();
59     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
60     TensorType resultTensorType =
61         castOp.getResult().getType().cast<TensorType>();
62     MemRefLayoutAttrInterface layout;
63 
64     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
65       if (resultTensorType.isa<RankedTensorType>())
66         layout = rankedMemRefType.getLayout();
67 
68     // Compute the new memref type.
69     Type resultMemRefType =
70         getMemRefType(resultTensorType, options, layout,
71                       sourceMemRefType.getMemorySpaceAsInt());
72 
73     // Replace the op with a memref.cast.
74     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
75                                              resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
77     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
78                                                  *resultBuffer);
79 
80     return success();
81   }
82 };
83 
84 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
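///
/// A rough example; shapes and reassociation indices are illustrative only,
/// and non-collapsible source layouts take the copy path implemented below:
///
///   %0 = tensor.collapse_shape %t [[0, 1]]
///       : tensor<2x4xf32> into tensor<8xf32>
///
/// bufferizes to approximately (with %m the buffer of %t):
///
///   %0 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x4xf32> into memref<8xf32>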
85 struct CollapseShapeOpInterface
86     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
87                                                     tensor::CollapseShapeOp> {
88   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
89                               const AnalysisState &state) const {
90     return false;
91   }
92 
93   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
94                                const AnalysisState &state) const {
95     return false;
96   }
97 
98   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
99                                             const AnalysisState &state) const {
100     if (&opOperand == &op->getOpOperand(0) /*src*/)
101       return {op->getOpResult(0)};
102     return {};
103   }
104 
105   BufferRelation bufferRelation(Operation *op, OpResult opResult,
106                                 const AnalysisState &state) const {
107     return BufferRelation::Equivalent;
108   }
109 
110   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
111                           const BufferizationOptions &options) const {
112     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
113     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
114     FailureOr<Value> maybeBuffer =
115         getBuffer(rewriter, collapseShapeOp.getSrc(), options);
116     if (failed(maybeBuffer))
117       return failure();
118     Value buffer = *maybeBuffer;
119     auto bufferType = buffer.getType().cast<MemRefType>();
120 
121     if (tensorResultType.getRank() == 0) {
122       // 0-d collapses must go through a different op builder.
123       MemRefType resultType;
124 
125       if (bufferType.getLayout().isIdentity()) {
126         // Standard layout: result type has no offset.
127         MemRefLayoutAttrInterface layout;
128         resultType = MemRefType::get({}, tensorResultType.getElementType(),
129                                      layout, bufferType.getMemorySpace());
130       } else {
131         // Source memref has a layout map: result type has the same offset as
132         // the source type.
133         SmallVector<int64_t> strides;
134         int64_t offset;
135         if (failed(getStridesAndOffset(bufferType, strides, offset)))
136           return failure();
137         AffineMap resultLayout =
138             makeStridedLinearLayoutMap({}, offset, op->getContext());
139         resultType =
140             MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
141                             bufferType.getMemorySpaceAsInt());
142       }
143 
144       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
145           rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
146       return success();
147     }
148 
149     // If the dims are not collapsible (due to an incompatible source layout
150     // map), force an out-of-place bufferization, i.e., a buffer copy. This
151     // newly allocated buffer will have no layout map and thus be collapsible.
152     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
153         bufferType, collapseShapeOp.getReassociationIndices());
154     if (!canBeCollapsed) {
155       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
156       AnalysisState analysisState(options);
157       FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
158           rewriter, op->getLoc(), collapseShapeOp.getSrc(),
159           analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
160       if (failed(tensorAlloc))
161         return failure();
162       auto memrefType =
163           MemRefType::get(collapseShapeOp.getSrcType().getShape(),
164                           collapseShapeOp.getSrcType().getElementType(),
165                           AffineMap(), bufferType.getMemorySpaceAsInt());
166       buffer = rewriter.create<bufferization::ToMemrefOp>(
167           op->getLoc(), memrefType, *tensorAlloc);
168     }
169 
170     // Result type is inferred by the builder.
171     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
172         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
173     return success();
174   }
175 };
176 
177 /// Bufferization of tensor.dim. Replace with memref.dim.
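///
/// Illustrative example (with %m the buffer of %t):
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// becomes:
///
///   %d = memref.dim %m, %c0 : memref<?xf32>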
178 struct DimOpInterface
179     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
180                                                     tensor::DimOp> {
181   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
182                               const AnalysisState &state) const {
183     return true;
184   }
185 
186   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
187                                const AnalysisState &state) const {
188     return false;
189   }
190 
191   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
192                                             const AnalysisState &state) const {
193     return {};
194   }
195 
196   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
197                           const BufferizationOptions &options) const {
198     auto dimOp = cast<tensor::DimOp>(op);
199     FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
200     if (failed(v))
201       return failure();
202     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
204     return success();
205   }
206 };
207 
208 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
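///
/// A rough example; shapes and reassociation indices are illustrative only:
///
///   %0 = tensor.expand_shape %t [[0, 1]]
///       : tensor<8xf32> into tensor<2x4xf32>
///
/// bufferizes to approximately (with %m the buffer of %t):
///
///   %0 = memref.expand_shape %m [[0, 1]]
///       : memref<8xf32> into memref<2x4xf32>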
209 struct ExpandShapeOpInterface
210     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
211                                                     tensor::ExpandShapeOp> {
212   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
213                               const AnalysisState &state) const {
214     return false;
215   }
216 
217   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
218                                const AnalysisState &state) const {
219     return false;
220   }
221 
222   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
223                                             const AnalysisState &state) const {
224     if (&opOperand == &op->getOpOperand(0) /*src*/)
225       return {op->getOpResult(0)};
226     return {};
227   }
228 
229   BufferRelation bufferRelation(Operation *op, OpResult opResult,
230                                 const AnalysisState &state) const {
231     return BufferRelation::Equivalent;
232   }
233 
234   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
235                           const BufferizationOptions &options) const {
236     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
237     auto tensorResultType = expandShapeOp.getResultType();
238     FailureOr<Value> buffer =
239         getBuffer(rewriter, expandShapeOp.getSrc(), options);
240     if (failed(buffer))
241       return failure();
242 
243     // Memref result type is inferred by the builder based on reassociation
244     // indices and result shape.
245     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
246         rewriter, op, tensorResultType.getShape(), *buffer,
247         expandShapeOp.getReassociationIndices());
248     return success();
249   }
250 };
251 
252 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
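///
/// A rough example; offsets/sizes/strides and the (elided) strided result
/// layout are illustrative only:
///
///   %0 = tensor.extract_slice %t[4] [8] [1]
///       : tensor<16xf32> to tensor<8xf32>
///
/// bufferizes to approximately (with %m the buffer of %t):
///
///   %0 = memref.subview %m[4] [8] [1]
///       : memref<16xf32> to memref<8xf32, ...>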
253 struct ExtractSliceOpInterface
254     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
255                                                     tensor::ExtractSliceOp> {
256   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
257                               const AnalysisState &state) const {
258     return false;
259   }
260 
261   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
262                                const AnalysisState &state) const {
263     return false;
264   }
265 
266   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
267                                             const AnalysisState &state) const {
268     if (&opOperand == &op->getOpOperand(0) /*source*/)
269       return {op->getOpResult(0)};
270     return {};
271   }
272 
273   BufferRelation bufferRelation(Operation *op, OpResult opResult,
274                                 const AnalysisState &state) const {
275     return BufferRelation::None;
276   }
277 
278   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
279                           const BufferizationOptions &options) const {
280     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
281     Location loc = extractSliceOp.getLoc();
282 
283     // Even if this op was decided to bufferize out-of-place, do not insert the
284     // buffer copy yet. This is done later in this function.
285     FailureOr<Value> srcMemref =
286         getBuffer(rewriter, extractSliceOp.getSource(), options);
287     if (failed(srcMemref))
288       return failure();
289     auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
290     auto dstTensorType =
291         extractSliceOp.getResult().getType().cast<RankedTensorType>();
292 
293     // Expand offsets, sizes and strides to the full rank to handle the
294     // rank-reducing case.
295     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
296     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
297     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
298     OffsetSizeAndStrideOpInterface::expandToRank(
299         *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
300         [&](Value target, int64_t dim) -> OpFoldResult {
301           auto shapedType = target.getType().cast<ShapedType>();
302           if (shapedType.isDynamicDim(dim))
303             return rewriter.create<memref::DimOp>(loc, target, dim).result();
304           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
305         });
306     // Bufferize to subview.
307     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
308                                  dstTensorType.getRank(), srcMemrefType,
309                                  mixedOffsets, mixedSizes, mixedStrides)
310                                  .cast<MemRefType>();
311     Value subView = rewriter.create<memref::SubViewOp>(
312         loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
313         mixedStrides);
314 
315     replaceOpWithBufferizedValues(rewriter, op, subView);
316     return success();
317   }
318 };
319 
320 /// Bufferization of tensor.extract. Replace with memref.load.
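///
/// Illustrative example (with %m the buffer of %t):
///
///   %e = tensor.extract %t[%i] : tensor<?xf32>
///
/// becomes:
///
///   %e = memref.load %m[%i] : memref<?xf32>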
321 struct ExtractOpInterface
322     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
323                                                     tensor::ExtractOp> {
324   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
325                               const AnalysisState &state) const {
326     return true;
327   }
328 
329   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
330                                const AnalysisState &state) const {
331     return false;
332   }
333 
334   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
335                                             const AnalysisState &state) const {
336     return {};
337   }
338 
339   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
340                           const BufferizationOptions &options) const {
341     auto extractOp = cast<tensor::ExtractOp>(op);
342     FailureOr<Value> srcMemref =
343         getBuffer(rewriter, extractOp.getTensor(), options);
344     if (failed(srcMemref))
345       return failure();
346     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
348     return success();
349   }
350 };
351 
// Recursively traverses the index space of the output buffer in row-major
// order, storing one value from op.getElements() per visited position.
354 static void createStores(RewriterBase &rewriter, Location loc, int dim,
355                          Value buffer, ArrayRef<int64_t> shape,
356                          ArrayRef<Value> constants,
357                          OperandRange::iterator &elementIt,
358                          SmallVectorImpl<Value> &indices) {
359   if (dim == static_cast<int>(shape.size()) - 1) {
360     for (int i = 0; i < shape.back(); ++i) {
361       indices.back() = constants[i];
362       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
363       ++elementIt;
364     }
365     return;
366   }
367   for (int i = 0; i < shape[dim]; ++i) {
368     indices[dim] = constants[i];
369     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
370                  indices);
371   }
372 }
373 
374 /// Bufferization of tensor.from_elements.
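///
/// A rough sketch with illustrative names; the allocation below actually goes
/// through bufferization.alloc_tensor / bufferization.to_memref:
///
///   %0 = tensor.from_elements %a, %b : tensor<2xf32>
///
/// bufferizes to approximately:
///
///   %m = ... newly allocated memref<2xf32> ...
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>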
375 struct FromElementsOpInterface
376     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
377                                                     tensor::FromElementsOp> {
378   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
379                           const BufferizationOptions &options) const {
380     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
381 
382     // Allocate a buffer for the result.
383     Location loc = op->getLoc();
384     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
385     auto shape = tensorType.getShape();
386     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
387     AnalysisState analysisState(options);
388     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
389         rewriter, loc, fromElementsOp.getResult(),
390         analysisState.isTensorYielded(fromElementsOp.getResult()), options,
391         /*copy=*/false);
392     if (failed(tensorAlloc))
393       return failure();
394     auto memrefType =
395         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
396     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
397         op->getLoc(), memrefType, *tensorAlloc);
398 
399     // Case: tensor<0xelem_type>.
400     if (fromElementsOp.getElements().empty()) {
401       replaceOpWithBufferizedValues(rewriter, op, buffer);
402       return success();
403     }
404 
405     // Case: tensor<elem_type>.
406     if (shape.empty()) {
407       rewriter.create<memref::StoreOp>(
408           loc, fromElementsOp.getElements().front(), buffer);
409       replaceOpWithBufferizedValues(rewriter, op, buffer);
410       return success();
411     }
412 
413     // Create constants for the range of possible indices [0, max{shape_i}).
414     auto maxDim = *std::max_element(shape.begin(), shape.end());
415     SmallVector<Value, 2> constants;
416     constants.reserve(maxDim);
417     for (int i = 0; i < maxDim; ++i)
418       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
419 
420     // Traverse all `elements` and create `memref.store` ops.
421     auto elementIt = fromElementsOp.getElements().begin();
422     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
423     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
424                  indices);
425 
426     replaceOpWithBufferizedValues(rewriter, op, buffer);
427     return success();
428   }
429 };
430 
431 /// Bufferization of tensor.generate.
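///
/// A rough sketch with illustrative names; the op's body is moved into an
/// scf.parallel loop that stores each yielded element into a newly allocated
/// buffer:
///
///   %0 = tensor.generate %n {
///   ^bb0(%i : index):
///     ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
///
/// becomes an allocation plus an scf.parallel over [0, %n) whose body ends
/// in "memref.store %v, %buffer[%i]".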
432 struct GenerateOpInterface
433     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
434                                                     tensor::GenerateOp> {
435   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
436                           const BufferizationOptions &options) const {
437     auto generateOp = cast<tensor::GenerateOp>(op);
438     auto tensorType = generateOp.getType().cast<RankedTensorType>();
439     // Allocate memory.
440     Location loc = op->getLoc();
441     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
442     AnalysisState analysisState(options);
443     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
444         rewriter, loc, generateOp.getResult(),
445         analysisState.isTensorYielded(generateOp.getResult()), options,
446         /*copy=*/false);
447     if (failed(tensorAlloc))
448       return failure();
449     auto memrefType =
450         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
451     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
452         op->getLoc(), memrefType, *tensorAlloc);
453 
454     // Collect loop bounds.
455     int64_t rank = memrefType.getRank();
456     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
457     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
458     SmallVector<Value, 4> lowerBounds(rank, zero);
459     SmallVector<Value, 4> steps(rank, one);
460     SmallVector<Value, 4> upperBounds;
461     int nextDynamicIndex = 0;
462     for (int i = 0; i < rank; i++) {
463       Value upperBound =
464           memrefType.isDynamicDim(i)
465               ? generateOp.getDynamicExtents()[nextDynamicIndex++]
466               : rewriter.create<arith::ConstantIndexOp>(
467                     loc, memrefType.getDimSize(i));
468       upperBounds.push_back(upperBound);
469     }
470 
471     // Generate tensor elements with a parallel loop that stores into
472     // each element of the resulting memref. We use mergeBlockBefore to "move"
473     // this op's body into the scf.parallel's body.
474     auto parallel =
475         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
476     Block *parallelBody = parallel.getBody();
477     rewriter.mergeBlockBefore(&generateOp.getBody().front(),
478                               parallelBody->getTerminator(),
479                               parallelBody->getArguments());
480     // Replace the inlined yield op with a store op. The scf.parallel's builder
481     // already populated an scf.yield at the end, so we don't need to worry
482     // about creating that.
483     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
484     rewriter.setInsertionPointAfter(elementYield);
485     rewriter.replaceOpWithNewOp<memref::StoreOp>(
486         elementYield, elementYield->getOperands()[0], buffer,
487         parallelBody->getArguments());
488 
489     replaceOpWithBufferizedValues(rewriter, op, buffer);
490     return success();
491   }
492 };
493 
494 /// Bufferization of tensor.insert. Replace with memref.store.
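///
/// Illustrative example (with %m the buffer of %t; uses of %0 are replaced
/// by the updated buffer %m):
///
///   %0 = tensor.insert %v into %t[%i] : tensor<?xf32>
///
/// becomes:
///
///   memref.store %v, %m[%i] : memref<?xf32>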
495 struct InsertOpInterface
496     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
497                                                     tensor::InsertOp> {
498   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
499                               const AnalysisState &state) const {
500     return true;
501   }
502 
503   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
504                                const AnalysisState &state) const {
505     return true;
506   }
507 
508   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
509                                             const AnalysisState &state) const {
510     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
511            "expected dest OpOperand");
512     return {op->getOpResult(0)};
513   }
514 
515   SmallVector<OpOperand *>
516   getAliasingOpOperand(Operation *op, OpResult opResult,
517                        const AnalysisState &state) const {
518     return {&op->getOpOperand(1) /*dest*/};
519   }
520 
521   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
522                           const BufferizationOptions &options) const {
523     auto insertOp = cast<tensor::InsertOp>(op);
524     FailureOr<Value> destMemref =
525         getBuffer(rewriter, insertOp.getDest(), options);
526     if (failed(destMemref))
527       return failure();
528     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
529                                      *destMemref, insertOp.getIndices());
530     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
531     return success();
532   }
533 
534   BufferRelation bufferRelation(Operation *op, OpResult opResult,
535                                 const AnalysisState &state) const {
536     return BufferRelation::Equivalent;
537   }
538 };
539 
540 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
541 /// equivalent operand / result and same offset/sizes/strides specification).
542 ///
543 /// This is one particular type of relationship between ops on tensors that
544 /// reduce to an equivalence on buffers. This should be generalized and
545 /// exposed as interfaces on the proper types.
546 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
547                                          ExtractSliceOp st, InsertSliceOp sti) {
548   if (!st || !sti)
549     return false;
  // The source of the ExtractSliceOp must bufferize to a buffer that is
  // equivalent to the dest of the InsertSliceOp.
  if (!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
552     return false;
553   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
554     return false;
555   return true;
556 }
557 
/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
560 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
561                                       InsertSliceOp insertOp) {
562   auto condition = [&](Value val) {
563     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
564       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
565         return true;
566     return false;
567   };
568 
569   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
570                       condition);
571 }
572 
573 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
574 /// certain circumstances, this op can also be a no-op.
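///
/// A rough example; offsets/sizes/strides and the (elided) subview layout
/// are illustrative only:
///
///   %0 = tensor.insert_slice %s into %t[4] [8] [1]
///       : tensor<8xf32> into tensor<16xf32>
///
/// bufferizes to approximately (with %src/%dst the buffers of %s/%t; uses of
/// %0 are replaced by %dst):
///
///   %sv = memref.subview %dst[4] [8] [1]
///       : memref<16xf32> to memref<8xf32, ...>
///   memref.copy %src, %sv : memref<8xf32> to memref<8xf32, ...>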
575 struct InsertSliceOpInterface
576     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
577                                                     tensor::InsertSliceOp> {
578   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
579                               const AnalysisState &state) const {
580     return true;
581   }
582 
583   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
584                                const AnalysisState &state) const {
585     return &opOperand == &op->getOpOperand(1) /*dest*/;
586   }
587 
588   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
589                                             const AnalysisState &state) const {
590     if (&opOperand == &op->getOpOperand(1) /*dest*/)
591       return {op->getResult(0)};
592     return {};
593   }
594 
595   BufferRelation bufferRelation(Operation *op, OpResult opResult,
596                                 const AnalysisState &state) const {
597     return BufferRelation::Equivalent;
598   }
599 
600   bool isNotConflicting(Operation *op, OpOperand *uRead,
601                         OpOperand *uConflictingWrite,
602                         const AnalysisState &state) const {
603     Operation *readingOp = uRead->getOwner();
604     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
605 
606     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
607     // uRead is an InsertSliceOp...
608     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
609       // As an example, consider the following IR.
610       //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
615 
616       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
617       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
618           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
619                                     insertSliceOp))
620         // Case 1: The main insight is that InsertSliceOp reads only part of
621         // the destination tensor. The overwritten area is not read. If
622         // uConflictingWrite writes into exactly the memory location that is
623         // being read by uRead, this is not a conflict.
624         //
625         // In the above example:
626         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
627         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
628         //
629         // The read of %t does not conflict with the write of the FillOp
630         // (same aliases!) because the area that the FillOp operates on is
631         // exactly the one that is *not* read via %t.
632         return true;
633 
634       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
635           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
636           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
637         // Case 2: The read of the source tensor and the write to the dest
638         // tensor via an InsertSliceOp is not a conflict if the read is
639         // reading exactly that part of an equivalent tensor that the
640         // InsertSliceOp is writing.
641         //
642         // In the above example:
643         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
644         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
645         return true;
646     }
647 
648     // If uConflictingWrite is an InsertSliceOp...
649     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
650       // As an example, consider the following IR.
651       //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
656       // %3 = vector.transfer_read %1, %cst
657       //
658       // In the above example:
659       // uRead             = OpOperand 0 (%1) of vector.transfer_read
660       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
661       // lastWrite         = %1
662       //
663       // This is not a conflict because the InsertSliceOp overwrites the
664       // memory segment of %1 with the exact same data. (Effectively, there
665       // is no memory write here.)
666       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
667           state.areEquivalentBufferizedValues(uRead->get(),
668                                               insertSliceOp.getSource()) &&
669           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
670                                     insertSliceOp))
671         return true;
672 
673     return false;
674   }
675 
676   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
677                           const BufferizationOptions &options) const {
678     // insert_slice ops arise from tiling and bufferizing them out-of-place is
679     // generally a deal breaker. When used with loops, this ends up cloning the
680     // whole tensor on every single iteration and is a symptom of a
681     // catastrophically bad scheduling decision.
682     // TODO: be very loud about it or even consider failing the pass.
683     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
684     Location loc = insertSliceOp.getLoc();
685     FailureOr<Value> dstMemref =
686         getBuffer(rewriter, insertSliceOp.getDest(), options);
687     if (failed(dstMemref))
688       return failure();
689 
690     // Expand offsets, sizes and strides to the full rank to handle the
691     // rank-reducing case.
692     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
693     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
694     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
695     OffsetSizeAndStrideOpInterface::expandToRank(
696         *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
697         [&](Value target, int64_t dim) -> OpFoldResult {
698           auto shapedType = target.getType().cast<ShapedType>();
699           if (shapedType.isDynamicDim(dim))
700             return rewriter.create<memref::DimOp>(loc, target, dim).result();
701           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
702         });
703     // Take a subview of the dst.
704     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
705     auto subviewMemRefType =
706         memref::SubViewOp::inferRankReducedResultType(
707             insertSliceOp.getSourceType().getRank(), dstMemrefType,
708             mixedOffsets, mixedSizes, mixedStrides)
709             .cast<MemRefType>();
710     Value subView = rewriter.create<memref::SubViewOp>(
711         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
712         mixedStrides);
713 
714     // Copy tensor. If this tensor.insert_slice has a matching
715     // tensor.extract_slice, the copy operation will eventually fold away.
716     FailureOr<Value> srcMemref =
717         getBuffer(rewriter, insertSliceOp.getSource(), options);
718     if (failed(srcMemref))
719       return failure();
720     if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
721       return failure();
722 
723     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
724     return success();
725   }
726 };
727 
728 /// Bufferization of tensor.rank. Replace with memref.rank.
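///
/// Illustrative example (with %m the buffer of %t):
///
///   %r = tensor.rank %t : tensor<*xf32>
///
/// becomes:
///
///   %r = memref.rank %m : memref<*xf32>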
729 struct RankOpInterface
730     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
731                                                     tensor::RankOp> {
732   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
733                               const AnalysisState &state) const {
734     return true;
735   }
736 
737   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
738                                const AnalysisState &state) const {
739     return false;
740   }
741 
742   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
743                                             const AnalysisState &state) const {
744     return {};
745   }
746 
747   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
748                           const BufferizationOptions &options) const {
749     auto rankOp = cast<tensor::RankOp>(op);
750     FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
751     if (failed(v))
752       return failure();
753     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
754                                                  *v);
755     return success();
756   }
757 };
758 
759 /// Bufferization of tensor.reshape. Replace with memref.reshape.
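///
/// A rough example with illustrative types; the shape operand is bufferized
/// as well:
///
///   %0 = tensor.reshape %t(%shape)
///       : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
///
/// becomes approximately (with %m/%s the buffers of %t/%shape):
///
///   %0 = memref.reshape %m(%s)
///       : (memref<?xf32>, memref<1xindex>) -> memref<?xf32>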
760 struct ReshapeOpInterface
761     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
762                                                     tensor::ReshapeOp> {
763   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
764                               const AnalysisState &state) const {
765     if (&opOperand == &op->getOpOperand(1) /* shape */)
766       return true;
767     return false;
768   }
769 
770   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
771                                const AnalysisState &state) const {
772     return false;
773   }
774 
775   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
776                                             const AnalysisState &state) const {
777     return {op->getOpResult(0)};
778   }
779 
780   BufferRelation bufferRelation(Operation *op, OpResult opResult,
781                                 const AnalysisState &state) const {
782     return BufferRelation::Equivalent;
783   }
784 
785   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
786                           const BufferizationOptions &options) const {
787     auto reshapeOp = cast<tensor::ReshapeOp>(op);
788     FailureOr<Value> srcBuffer =
789         getBuffer(rewriter, reshapeOp.getSource(), options);
790     FailureOr<Value> shapeBuffer =
791         getBuffer(rewriter, reshapeOp.getShape(), options);
792     if (failed(srcBuffer) || failed(shapeBuffer))
793       return failure();
794     auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
795     auto resultMemRefType = getMemRefType(resultTensorType, options);
796     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
797         rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
798     return success();
799   }
800 };
801 
802 } // namespace
803 } // namespace tensor
804 } // namespace mlir
805 
806 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
807     DialectRegistry &registry) {
808   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
809     CastOp::attachInterface<CastOpInterface>(*ctx);
810     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
811     DimOp::attachInterface<DimOpInterface>(*ctx);
812     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
813     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
814     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
815     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
816     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
817     InsertOp::attachInterface<InsertOpInterface>(*ctx);
818     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
819     RankOp::attachInterface<RankOpInterface>(*ctx);
820     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
821   });
822 }
823