//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

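/// Bufferization of tensor.cast. Replace with memref.cast.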
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      auto buffer = state.getBuffer(rewriter, srcOperand);
      if (failed(buffer))
        return failure();
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    Optional<BufferizationState::ForceInPlacability> overrideInPlace =
        canBeCollapsed
            ? None
            : Optional<BufferizationState::ForceInPlacability>(
                  BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
    auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace);
    if (failed(buffer))
      return failure();

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, *buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    auto buffer =
        state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

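  // The result is a subview of (usually only part of) the source buffer, so
  // the two buffers alias but are not equivalent.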
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    auto srcMemref =
        state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                        BufferizationState::ForceInPlacability::FORCE_INPLACE);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    auto srcMemref =
        state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Recursively traverses the indices of the output buffer in row-major order
// while iterating over op.elements(), creating one memref.store per element.
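// For example, for a tensor<2x3xf32>, stores are emitted for the indices
// (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), consuming the elements in
// that order.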
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
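  // tensor.from_elements has no tensor operands, so the memory read/write and
  // aliasing queries are not needed; only `bufferize` is implemented.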
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
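  // Apart from the index-typed dynamic extents, tensor.generate has no tensor
  // operands, so only `bufferize` is implemented here.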
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

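  // Every value at which the reverse use-def chain traversal stops must be
  // produced by an ExtractSliceOp that matches `insertOp`.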
  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates a new buffer.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    auto srcMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
                                                 subView, state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    auto &srcOperand = reshapeOp->getOpOperand(0);
    auto srcBuffer = state.getBuffer(rewriter, srcOperand);
    if (failed(srcBuffer))
      return failure();

    auto &shapeOperand = reshapeOp->getOpOperand(1);
    auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
    if (failed(shapeBuffer))
      return failure();

    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}