1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 using namespace mlir::tensor;
20 
21 namespace mlir {
22 namespace tensor {
23 namespace {
24 
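/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// For illustration only (a rough sketch; the exact memref layout depends on
/// the bufferization options and on the source buffer):
///
///   %1 = tensor.cast %0 : tensor<?x16xf32> to tensor<8x16xf32>
///
/// bufferizes to something like:
///
///   %m1 = memref.cast %m0 : memref<?x16xf32> to memref<8x16xf32>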
25 struct CastOpInterface
26     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
27                                                     tensor::CastOp> {
28   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
29                               const AnalysisState &state) const {
30     return false;
31   }
32 
33   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
34                                const AnalysisState &state) const {
35     return false;
36   }
37 
38   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
39                                             const AnalysisState &state) const {
40     return {op->getResult(0)};
41   }
42 
43   BufferRelation bufferRelation(Operation *op, OpResult opResult,
44                                 const AnalysisState &state) const {
45     return BufferRelation::Equivalent;
46   }
47 
48   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49                           BufferizationState &state) const {
50     auto castOp = cast<tensor::CastOp>(op);
51 
52     // The result buffer still has the old (pre-cast) type.
53     FailureOr<Value> resultBuffer =
54         state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
55     if (failed(resultBuffer))
56       return failure();
57     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
58     Attribute memorySpace = sourceMemRefType.getMemorySpace();
59     TensorType resultTensorType =
60         castOp.getResult().getType().cast<TensorType>();
61     MemRefLayoutAttrInterface layout;
62 
63     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
64       if (resultTensorType.isa<RankedTensorType>())
65         layout = rankedMemRefType.getLayout();
66 
67     // Compute the new memref type.
68     Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
69                                           layout, memorySpace);
70 
71     // Replace the op with a memref.cast.
72     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
73                                              resultMemRefType) &&
74            "CallOp::bufferize: cast incompatible");
75     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
76                                                  *resultBuffer);
77 
78     return success();
79   }
80 };
81 
82 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
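///
/// For illustration only (a sketch; the result layout is computed from the
/// source buffer):
///
///   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
///       : tensor<2x3x4xf32> into tensor<6x4xf32>
///
/// bufferizes to roughly:
///
///   %m1 = memref.collapse_shape %m0 [[0, 1], [2]]
///       : memref<2x3x4xf32> into memref<6x4xf32>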
83 struct CollapseShapeOpInterface
84     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
85                                                     tensor::CollapseShapeOp> {
86   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
87                               const AnalysisState &state) const {
88     return false;
89   }
90 
91   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
92                                const AnalysisState &state) const {
93     return false;
94   }
95 
96   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
97                                             const AnalysisState &state) const {
98     if (&opOperand == &op->getOpOperand(0) /*src*/)
99       return {op->getOpResult(0)};
100     return {};
101   }
102 
103   BufferRelation bufferRelation(Operation *op, OpResult opResult,
104                                 const AnalysisState &state) const {
105     return BufferRelation::Equivalent;
106   }
107 
108   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
109                           BufferizationState &state) const {
110     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
111     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
112     OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
113     auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();
114 
115     if (tensorResultType.getRank() == 0) {
116       // 0-d collapses must go through a different op builder.
117       Value buffer = *state.getBuffer(rewriter, srcOperand);
118       MemRefType resultType;
119 
120       if (bufferType.getLayout().isIdentity()) {
121         // Standard layout: result type has no offset.
122         MemRefLayoutAttrInterface layout;
123         resultType = MemRefType::get({}, tensorResultType.getElementType(),
124                                      layout, bufferType.getMemorySpace());
125       } else {
126         // Source memref has a layout map: result type has the same offset as
127         // the source type.
128         SmallVector<int64_t> strides;
129         int64_t offset;
130         if (failed(getStridesAndOffset(bufferType, strides, offset)))
131           return failure();
132         AffineMap resultLayout =
133             makeStridedLinearLayoutMap({}, offset, op->getContext());
134         resultType =
135             MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
136                             bufferType.getMemorySpaceAsInt());
137       }
138 
139       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
140           rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
141       return success();
142     }
143 
144     // If the dims are not collapsible (due to an incompatible source layout
145     // map), force an out-of-place bufferization, i.e., a buffer copy. This
146     // newly allocated buffer will have no layout map and thus be collapsible.
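    // For example (an illustrative sketch, using an affine-map layout): a
    // source buffer such as
    //   memref<2x3xf32, affine_map<(d0, d1)[s0] -> (d0 * 10 + d1 + s0)>>
    // cannot be collapsed along [[0, 1]] because its rows are not contiguous
    // in memory, so it is first copied into an identity-layout allocation.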
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
149     Optional<BufferizationState::ForceInPlacability> overrideInPlace =
150         canBeCollapsed
151             ? None
152             : Optional<BufferizationState::ForceInPlacability>(
153                   BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
154     Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
155 
156     // Result type is inferred by the builder.
157     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
158         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
159     return success();
160   }
161 };
162 
163 /// Bufferization of tensor.dim. Replace with memref.dim.
164 struct DimOpInterface
165     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
166                                                     tensor::DimOp> {
167   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
168                               const AnalysisState &state) const {
169     return true;
170   }
171 
172   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
173                                const AnalysisState &state) const {
174     return false;
175   }
176 
177   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
178                                             const AnalysisState &state) const {
179     return {};
180   }
181 
182   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
183                           BufferizationState &state) const {
184     auto dimOp = cast<tensor::DimOp>(op);
185     Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
186     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
187     return success();
188   }
189 };
190 
191 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
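///
/// For illustration only (a sketch; the memref result type is inferred by the
/// builder):
///
///   %1 = tensor.expand_shape %0 [[0, 1], [2]]
///       : tensor<6x4xf32> into tensor<2x3x4xf32>
///
/// bufferizes to roughly:
///
///   %m1 = memref.expand_shape %m0 [[0, 1], [2]]
///       : memref<6x4xf32> into memref<2x3x4xf32>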
192 struct ExpandShapeOpInterface
193     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
194                                                     tensor::ExpandShapeOp> {
195   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
196                               const AnalysisState &state) const {
197     return false;
198   }
199 
200   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
201                                const AnalysisState &state) const {
202     return false;
203   }
204 
205   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
206                                             const AnalysisState &state) const {
207     if (&opOperand == &op->getOpOperand(0) /*src*/)
208       return {op->getOpResult(0)};
209     return {};
210   }
211 
212   BufferRelation bufferRelation(Operation *op, OpResult opResult,
213                                 const AnalysisState &state) const {
214     return BufferRelation::Equivalent;
215   }
216 
217   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
218                           BufferizationState &state) const {
219     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
220     auto tensorResultType = expandShapeOp.getResultType();
221     Value buffer =
222         *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
223 
224     // Memref result type is inferred by the builder based on reassociation
225     // indices and result shape.
226     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
227         rewriter, op, tensorResultType.getShape(), buffer,
228         expandShapeOp.getReassociationIndices());
229     return success();
230   }
231 };
232 
233 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
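///
/// For illustration only (a sketch; `#map` stands for the strided result
/// layout computed by the subview builder):
///
///   %1 = tensor.extract_slice %0[0] [8] [2] : tensor<16xf32> to tensor<8xf32>
///
/// bufferizes (when in-place) to roughly:
///
///   %v = memref.subview %m0[0] [8] [2] : memref<16xf32> to memref<8xf32, #map>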
234 struct ExtractSliceOpInterface
235     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
236                                                     tensor::ExtractSliceOp> {
237   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
238                               const AnalysisState &state) const {
239     return false;
240   }
241 
242   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
243                                const AnalysisState &state) const {
244     return false;
245   }
246 
247   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
248                                             const AnalysisState &state) const {
249     if (&opOperand == &op->getOpOperand(0) /*source*/)
250       return {op->getOpResult(0)};
251     return {};
252   }
253 
254   BufferRelation bufferRelation(Operation *op, OpResult opResult,
255                                 const AnalysisState &state) const {
256     return BufferRelation::None;
257   }
258 
259   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
260                           BufferizationState &state) const {
261     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
262     Location loc = extractSliceOp.getLoc();
263 
264     // Even if this op was decided to bufferize out-of-place, do not insert the
265     // buffer copy yet. This is done later in this function.
266     Value srcMemref =
267         *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
268                          BufferizationState::ForceInPlacability::FORCE_INPLACE);
269     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
270     auto dstTensorType =
271         extractSliceOp.result().getType().cast<RankedTensorType>();
272 
273     // If not inplaceable, alloc.
274     bool inplace =
275         state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
276     Value alloc;
277     if (!inplace) {
278       FailureOr<Value> allocOrFailure =
279           state.createAlloc(rewriter, loc, extractSliceOp.result());
280       if (failed(allocOrFailure))
281         return failure();
282       alloc = *allocOrFailure;
283     }
284 
285     // Expand offsets, sizes and strides to the full rank to handle the
286     // rank-reducing case.
287     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
288     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
289     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
290     OffsetSizeAndStrideOpInterface::expandToRank(
291         srcMemref, mixedOffsets, mixedSizes, mixedStrides,
292         [&](Value target, int64_t dim) -> OpFoldResult {
293           auto shapedType = target.getType().cast<ShapedType>();
294           if (shapedType.isDynamicDim(dim))
295             return rewriter.create<memref::DimOp>(loc, target, dim).result();
296           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
297         });
298     // Bufferize to subview.
299     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
300                                  dstTensorType.getRank(), srcMemrefType,
301                                  mixedOffsets, mixedSizes, mixedStrides)
302                                  .cast<MemRefType>();
303     Value subView = rewriter.create<memref::SubViewOp>(
304         loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
305         mixedStrides);
306 
307     // If not inplaceable, copy.
308     if (!inplace) {
309       // Do not copy if the copied data is never read.
310       if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
311         if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
312                                 alloc, state.getOptions())))
313           return failure();
314       subView = alloc;
315     }
316 
317     replaceOpWithBufferizedValues(rewriter, op, subView);
318     return success();
319   }
320 };
321 
322 /// Bufferization of tensor.extract. Replace with memref.load.
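///
/// For illustration only (a sketch; `#layout` stands for whatever layout the
/// source buffer has):
///
///   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
///
/// bufferizes to roughly:
///
///   %e = memref.load %m[%i, %j] : memref<?x?xf32, #layout>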
323 struct ExtractOpInterface
324     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
325                                                     tensor::ExtractOp> {
326   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
327                               const AnalysisState &state) const {
328     return true;
329   }
330 
331   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
332                                const AnalysisState &state) const {
333     return false;
334   }
335 
336   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
337                                             const AnalysisState &state) const {
338     return {};
339   }
340 
341   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
342                           BufferizationState &state) const {
343     auto extractOp = cast<tensor::ExtractOp>(op);
344     Value srcMemref =
345         *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
346     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
347                                                  extractOp.indices());
348     return success();
349   }
350 };
351 
// Recursively traverses the indices of the output buffer in row-major order
// and creates one memref.store per element, consuming op.elements() via
// `elementIt` as it goes.
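// For example (sketch), for a tensor<2x3xf32> result the stores are emitted
// at indices (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), consuming the
// elements in exactly that order.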
354 static void createStores(RewriterBase &rewriter, Location loc, int dim,
355                          Value buffer, ArrayRef<int64_t> shape,
356                          ArrayRef<Value> constants,
357                          OperandRange::iterator &elementIt,
358                          SmallVectorImpl<Value> &indices) {
359   if (dim == static_cast<int>(shape.size()) - 1) {
360     for (int i = 0; i < shape.back(); ++i) {
361       indices.back() = constants[i];
362       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
363       ++elementIt;
364     }
365     return;
366   }
367   for (int i = 0; i < shape[dim]; ++i) {
368     indices[dim] = constants[i];
369     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
370                  indices);
371   }
372 }
373 
374 /// Bufferization of tensor.from_elements.
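///
/// For illustration only (a rough sketch; the actual allocation is created
/// via `state.createAlloc` and thus depends on the bufferization options):
///
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
///
/// bufferizes to roughly:
///
///   %buf = memref.alloc() : memref<2xf32>
///   memref.store %a, %buf[%c0] : memref<2xf32>
///   memref.store %b, %buf[%c1] : memref<2xf32>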
375 struct FromElementsOpInterface
376     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
377                                                     tensor::FromElementsOp> {
378   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
379                           BufferizationState &state) const {
380     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
381 
382     // Allocate a buffer for the result.
383     Location loc = op->getLoc();
384     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
385     auto shape = tensorType.getShape();
386     FailureOr<Value> maybeBuffer =
387         state.createAlloc(rewriter, loc, fromElementsOp.result());
388     if (failed(maybeBuffer))
389       return failure();
390     Value buffer = *maybeBuffer;
391 
392     // Case: tensor<0xelem_type>.
393     if (fromElementsOp.elements().empty()) {
394       replaceOpWithBufferizedValues(rewriter, op, buffer);
395       return success();
396     }
397 
398     // Case: tensor<elem_type>.
399     if (shape.empty()) {
400       rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
401                                        buffer);
402       replaceOpWithBufferizedValues(rewriter, op, buffer);
403       return success();
404     }
405 
406     // Create constants for the range of possible indices [0, max{shape_i}).
407     auto maxDim = *std::max_element(shape.begin(), shape.end());
408     SmallVector<Value, 2> constants;
409     constants.reserve(maxDim);
410     for (int i = 0; i < maxDim; ++i)
411       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
412 
413     // Traverse all `elements` and create `memref.store` ops.
414     auto elementIt = fromElementsOp.elements().begin();
415     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
416     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
417                  indices);
418 
419     replaceOpWithBufferizedValues(rewriter, op, buffer);
420     return success();
421   }
422 };
423 
424 /// Bufferization of tensor.generate.
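///
/// For illustration only (a rough sketch; the actual allocation is created
/// via `state.createAlloc` and thus depends on the bufferization options):
///
///   %t = tensor.generate %sz {
///   ^bb0(%i : index):
///     %v = ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
///
/// bufferizes to roughly:
///
///   %buf = memref.alloc(%sz) : memref<?xf32>
///   scf.parallel (%i) = (%c0) to (%sz) step (%c1) {
///     %v = ...
///     memref.store %v, %buf[%i] : memref<?xf32>
///   }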
425 struct GenerateOpInterface
426     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
427                                                     tensor::GenerateOp> {
428   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
429                           BufferizationState &state) const {
430     auto generateOp = cast<tensor::GenerateOp>(op);
431 
432     // Allocate memory.
433     Location loc = op->getLoc();
434     MemRefType memrefType =
435         getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
436     FailureOr<Value> maybeResult =
437         state.createAlloc(rewriter, loc, generateOp.result());
438     if (failed(maybeResult))
439       return failure();
440     Value result = *maybeResult;
441 
442     // Collect loop bounds.
443     int64_t rank = memrefType.getRank();
444     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
445     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
446     SmallVector<Value, 4> lowerBounds(rank, zero);
447     SmallVector<Value, 4> steps(rank, one);
448     SmallVector<Value, 4> upperBounds;
449     int nextDynamicIndex = 0;
450     for (int i = 0; i < rank; i++) {
451       Value upperBound = memrefType.isDynamicDim(i)
452                              ? generateOp.dynamicExtents()[nextDynamicIndex++]
453                              : rewriter.create<arith::ConstantIndexOp>(
454                                    loc, memrefType.getDimSize(i));
455       upperBounds.push_back(upperBound);
456     }
457 
458     // Generate tensor elements with a parallel loop that stores into
459     // each element of the resulting memref. We use mergeBlockBefore to "move"
460     // this op's body into the scf.parallel's body.
461     auto parallel =
462         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
463     Block *parallelBody = parallel.getBody();
464     rewriter.mergeBlockBefore(generateOp.getBody(),
465                               parallelBody->getTerminator(),
466                               parallelBody->getArguments());
467     // Replace the inlined yield op with a store op. The scf.parallel's builder
468     // already populated an scf.yield at the end, so we don't need to worry
469     // about creating that.
470     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
471     rewriter.setInsertionPointAfter(elementYield);
472     rewriter.replaceOpWithNewOp<memref::StoreOp>(
473         elementYield, elementYield->getOperands()[0], result,
474         parallelBody->getArguments());
475 
476     replaceOpWithBufferizedValues(rewriter, op, result);
477     return success();
478   }
479 };
480 
481 /// Bufferization of tensor.insert. Replace with memref.store.
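///
/// For illustration only (a sketch; `#layout` stands for whatever layout the
/// destination buffer has):
///
///   %1 = tensor.insert %f into %t[%i] : tensor<?xf32>
///
/// bufferizes (when in-place) to roughly:
///
///   memref.store %f, %m[%i] : memref<?xf32, #layout>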
482 struct InsertOpInterface
483     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
484                                                     tensor::InsertOp> {
485   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
486                               const AnalysisState &state) const {
487     return true;
488   }
489 
490   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
491                                const AnalysisState &state) const {
492     return true;
493   }
494 
495   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
496                                             const AnalysisState &state) const {
497     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
498            "expected dest OpOperand");
499     return {op->getOpResult(0)};
500   }
501 
502   SmallVector<OpOperand *>
503   getAliasingOpOperand(Operation *op, OpResult opResult,
504                        const AnalysisState &state) const {
505     return {&op->getOpOperand(1) /*dest*/};
506   }
507 
508   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
509                           BufferizationState &state) const {
510     auto insertOp = cast<tensor::InsertOp>(op);
511     FailureOr<Value> destMemref =
512         state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
513     if (failed(destMemref))
514       return failure();
515     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
516                                      *destMemref, insertOp.indices());
517     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
518     return success();
519   }
520 
521   BufferRelation bufferRelation(Operation *op, OpResult opResult,
522                                 const AnalysisState &state) const {
523     return BufferRelation::Equivalent;
524   }
525 };
526 
527 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
528 /// equivalent operand / result and same offset/sizes/strides specification).
529 ///
530 /// This is one particular type of relationship between ops on tensors that
531 /// reduce to an equivalence on buffers. This should be generalized and
532 /// exposed as interfaces on the proper types.
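///
/// For example (sketch), the following pair matches: both slices use the same
/// offsets/sizes/strides and %0 is extracted from the same tensor that %1 is
/// inserted into.
///
///   %0 = tensor.extract_slice %t[%a] [%c] [1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   %1 = tensor.insert_slice %x into %t[%a] [%c] [1]
///       : tensor<?xf32> into tensor<?xf32>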
533 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
534                                          ExtractSliceOp st, InsertSliceOp sti) {
535   if (!st || !sti)
536     return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
539     return false;
540   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
541     return false;
542   return true;
543 }
544 
/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
547 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
548                                       InsertSliceOp insertOp) {
549   auto condition = [&](Value val) {
550     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
551       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
552         return true;
553     return false;
554   };
555 
556   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
557                       condition);
558 }
559 
560 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
561 /// certain circumstances, this op can also be a no-op.
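///
/// For illustration only (a rough sketch; `#map` stands for the strided
/// subview layout and the copy may be materialized differently depending on
/// the bufferization options):
///
///   %2 = tensor.insert_slice %1 into %t[0] [8] [2]
///       : tensor<8xf32> into tensor<16xf32>
///
/// bufferizes to roughly:
///
///   %v = memref.subview %m_t[0] [8] [2] : memref<16xf32> to memref<8xf32, #map>
///   memref.copy %m_1, %v : memref<8xf32> to memref<8xf32, #map>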
562 struct InsertSliceOpInterface
563     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
564                                                     tensor::InsertSliceOp> {
565   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
566                               const AnalysisState &state) const {
567     return true;
568   }
569 
570   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
571                                const AnalysisState &state) const {
572     return &opOperand == &op->getOpOperand(1) /*dest*/;
573   }
574 
575   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
576                                             const AnalysisState &state) const {
577     if (&opOperand == &op->getOpOperand(1) /*dest*/)
578       return {op->getResult(0)};
579     return {};
580   }
581 
582   BufferRelation bufferRelation(Operation *op, OpResult opResult,
583                                 const AnalysisState &state) const {
584     return BufferRelation::Equivalent;
585   }
586 
587   bool isNotConflicting(Operation *op, OpOperand *uRead,
588                         OpOperand *uConflictingWrite,
589                         const AnalysisState &state) const {
590     Operation *readingOp = uRead->getOwner();
591     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
592 
593     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
594     // uRead is an InsertSliceOp...
595     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
596       // As an example, consider the following IR.
597       //
598       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
599       // %1 = linalg.fill %cst, %0 {inplace= [true] }
600       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
601       //     {inplace= [true] }
602 
603       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
604       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
605           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
606                                     insertSliceOp))
607         // Case 1: The main insight is that InsertSliceOp reads only part of
608         // the destination tensor. The overwritten area is not read. If
609         // uConflictingWrite writes into exactly the memory location that is
610         // being read by uRead, this is not a conflict.
611         //
612         // In the above example:
613         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
614         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
615         //
616         // The read of %t does not conflict with the write of the FillOp
617         // (same aliases!) because the area that the FillOp operates on is
618         // exactly the one that is *not* read via %t.
619         return true;
620 
621       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
622           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
623           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
624         // Case 2: The read of the source tensor and the write to the dest
625         // tensor via an InsertSliceOp is not a conflict if the read is
626         // reading exactly that part of an equivalent tensor that the
627         // InsertSliceOp is writing.
628         //
629         // In the above example:
630         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
631         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
632         return true;
633     }
634 
635     // If uConflictingWrite is an InsertSliceOp...
636     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
637       // As an example, consider the following IR.
638       //
639       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
640       // %1 = linalg.fill %cst, %0 {inplace= [true] }
641       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
642       //     {inplace= [true] }
643       // %3 = vector.transfer_read %1, %cst
644       //
645       // In the above example:
646       // uRead             = OpOperand 0 (%1) of vector.transfer_read
647       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
648       // lastWrite         = %1
649       //
650       // This is not a conflict because the InsertSliceOp overwrites the
651       // memory segment of %1 with the exact same data. (Effectively, there
652       // is no memory write here.)
653       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
654           state.areEquivalentBufferizedValues(uRead->get(),
655                                               insertSliceOp.source()) &&
656           hasMatchingExtractSliceOp(state, insertSliceOp.source(),
657                                     insertSliceOp))
658         return true;
659 
660     return false;
661   }
662 
663   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
664                           BufferizationState &state) const {
665     // insert_slice ops arise from tiling and bufferizing them out-of-place is
666     // generally a deal breaker. When used with loops, this ends up cloning the
667     // whole tensor on every single iteration and is a symptom of a
668     // catastrophically bad scheduling decision.
669     // TODO: be very loud about it or even consider failing the pass.
670     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
671     Location loc = insertSliceOp.getLoc();
672 
    // When bufferizing out-of-place, `getBuffer` allocates a new destination
    // buffer.
674     FailureOr<Value> dstMemref =
675         state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
676     if (failed(dstMemref))
677       return failure();
678 
679     // Expand offsets, sizes and strides to the full rank to handle the
680     // rank-reducing case.
681     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
682     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
683     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
684     OffsetSizeAndStrideOpInterface::expandToRank(
685         *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
686         [&](Value target, int64_t dim) -> OpFoldResult {
687           auto shapedType = target.getType().cast<ShapedType>();
688           if (shapedType.isDynamicDim(dim))
689             return rewriter.create<memref::DimOp>(loc, target, dim).result();
690           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
691         });
692     // Take a subview of the dst.
693     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
694     auto subviewMemRefType =
695         memref::SubViewOp::inferRankReducedResultType(
696             insertSliceOp.getSourceType().getRank(), dstMemrefType,
697             mixedOffsets, mixedSizes, mixedStrides)
698             .cast<MemRefType>();
699     Value subView = rewriter.create<memref::SubViewOp>(
700         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
701         mixedStrides);
702 
703     // Copy tensor. If this tensor.insert_slice has a matching
704     // tensor.extract_slice, the copy operation will eventually fold away.
705     Value srcMemref =
706         *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
707     if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
708                             state.getOptions())))
709       return failure();
710 
711     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
712     return success();
713   }
714 };
715 
716 /// Bufferization of tensor.rank. Replace with memref.rank.
717 struct RankOpInterface
718     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
719                                                     tensor::RankOp> {
720   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
721                               const AnalysisState &state) const {
722     return true;
723   }
724 
725   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
726                                const AnalysisState &state) const {
727     return false;
728   }
729 
730   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
731                                             const AnalysisState &state) const {
732     return {};
733   }
734 
735   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
736                           BufferizationState &state) const {
737     auto rankOp = cast<tensor::RankOp>(op);
738     Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
739     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
740                                                  v);
741     return success();
742   }
743 };
744 
745 } // namespace
746 } // namespace tensor
747 } // namespace mlir
748 
749 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
750     DialectRegistry &registry) {
751   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
752     CastOp::attachInterface<CastOpInterface>(*ctx);
753     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
754     DimOp::attachInterface<DimOpInterface>(*ctx);
755     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
756     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
757     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
758     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
759     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
760     InsertOp::attachInterface<InsertOpInterface>(*ctx);
761     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
762     RankOp::attachInterface<RankOpInterface>(*ctx);
763   });
764 }
765