//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

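/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// A minimal sketch of the rewrite (memref layouts and memory spaces depend
/// on the bufferization options and are only illustrative here):
///
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
/// becomes, with %m being the buffer of %0:
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>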
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
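///
/// A minimal sketch of the rewrite (assuming the source layout allows an
/// in-place collapse; types are illustrative only):
///
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// becomes, with %m being the buffer of %0:
///   %1 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>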
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      auto buffer = state.getBuffer(rewriter, srcOperand);
      if (failed(buffer))
        return failure();
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    Optional<BufferizationState::ForceInPlacability> overrideInPlace =
        canBeCollapsed
            ? None
            : Optional<BufferizationState::ForceInPlacability>(
                  BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
    auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace);
    if (failed(buffer))
      return failure();

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, *buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
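///
/// For example (types illustrative only):
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// becomes, with %m being the buffer of %t:
///   %d = memref.dim %m, %c0 : memref<?xf32>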
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
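///
/// A minimal sketch of the rewrite (types illustrative only):
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
/// becomes, with %m being the buffer of %0:
///   %1 = memref.expand_shape %m [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>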
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    auto buffer =
        state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
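///
/// When the slice bufferizes in-place, the rewrite is roughly (types
/// illustrative; the exact strided result layout depends on the source
/// buffer):
///
///   %1 = tensor.extract_slice %0[4][8][1] : tensor<16xf32> to tensor<8xf32>
/// becomes, with %m being the buffer of %0:
///   %1 = memref.subview %m[4][8][1] : memref<16xf32> to memref<8xf32, ...>
///
/// When bufferizing out-of-place, the subview is additionally copied into a
/// newly allocated buffer (unless the result is never read).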
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    auto srcMemref =
        state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                        BufferizationState::ForceInPlacability::FORCE_INPLACE);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(state.getOptions().createMemCpy(
                rewriter, extractSliceOp.getLoc(), subView, alloc)))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
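///
/// For example (types illustrative only):
///   %v = tensor.extract %t[%i] : tensor<?xf32>
/// becomes, with %m being the buffer of %t:
///   %v = memref.load %m[%i] : memref<?xf32>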
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    auto srcMemref =
        state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
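//
// For example, for a static shape [2, 3], the recursion emits stores at
// indices (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), consuming the elements
// in the same (row-major) order.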
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
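///
/// A rough sketch of the lowering (%c0/%c1 are index constants; the exact
/// allocation depends on the bufferization options):
///
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
/// becomes approximately:
///   %buf = memref.alloc() : memref<2xf32>
///   memref.store %a, %buf[%c0] : memref<2xf32>
///   memref.store %b, %buf[%c1] : memref<2xf32>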
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
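///
/// A rough sketch of the lowering (illustrative only; the allocation and the
/// loop bounds depend on the bufferization options and the dynamic extents):
///
///   %t = tensor.generate %n {
///   ^bb0(%i: index):
///     ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
/// becomes approximately:
///   %buf = memref.alloc(%n) : memref<?xf32>
///   scf.parallel (%i) = (%c0) to (%n) step (%c1) {
///     ...
///     memref.store %v, %buf[%i] : memref<?xf32>
///     scf.yield
///   }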
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;
    MemRefType memrefType = result.getType().cast<MemRefType>();

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
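///
/// For example, in the in-place case (types illustrative only):
///   %1 = tensor.insert %v into %t[%i] : tensor<?xf32>
/// becomes, with %m being the buffer of %t:
///   memref.store %v, %m[%i] : memref<?xf32>
/// and uses of %1 are replaced with %m.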
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
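///
/// Example of a matching pair (illustrative): same offsets, sizes and strides,
/// and the extract's source is equivalent to the insert's destination:
///   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   %1 = tensor.insert_slice %x into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>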
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
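///
/// A rough sketch of the in-place rewrite (types illustrative; the copy op is
/// whatever the bufferization options' createMemCpy hook emits):
///
///   %1 = tensor.insert_slice %s into %t[4][8][1]
///       : tensor<8xf32> into tensor<16xf32>
/// becomes, with %src/%dst being the buffers of %s/%t:
///   %sv = memref.subview %dst[4][8][1] : memref<16xf32> to memref<8xf32, ...>
///   memref.copy %src, %sv
/// If this insert_slice has a matching extract_slice, the copy is expected to
/// fold away later.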
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    auto srcMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(srcMemref) || failed(state.getOptions().createMemCpy(
                                 rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
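///
/// For example (types illustrative only):
///   %1 = tensor.reshape %0(%shape)
///       : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
/// becomes, with %m/%s being the buffers of %0/%shape:
///   %1 = memref.reshape %m(%s)
///       : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>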
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    auto &srcOperand = reshapeOp->getOpOperand(0);
    auto srcBuffer = state.getBuffer(rewriter, srcOperand);
    if (failed(srcBuffer))
      return failure();

    auto &shapeOperand = reshapeOp->getOpOperand(1);
    auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
    if (failed(shapeBuffer))
      return failure();

    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}