//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

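/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Illustrative example (shapes and value names are hypothetical):
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
/// bufferizes roughly to:
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>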
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(resultTensorType, options, layout,
                      sourceMemRefType.getMemorySpaceAsInt());

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
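///
/// Illustrative example (shapes and reassociation are hypothetical):
///   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
///       : tensor<2x3x4xf32> into tensor<6x4xf32>
/// bufferizes, when the source layout allows it, roughly to:
///   %1 = memref.collapse_shape %0 [[0, 1], [2]]
///       : memref<2x3x4xf32> into memref<6x4xf32>
/// Otherwise the source is first copied into a new identity-layout buffer.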
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      Value tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()));
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
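///
/// Illustrative example (names are hypothetical):
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// bufferizes roughly to:
///   %d = memref.dim %m, %c0 : memref<?xf32>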
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
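///
/// Illustrative example (shapes and reassociation are hypothetical):
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
/// bufferizes roughly to:
///   %1 = memref.expand_shape %0 [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>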
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
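///
/// Illustrative example (offsets/sizes/strides are hypothetical; the subview
/// result carries a strided layout inferred by memref.subview, shown here as
/// a placeholder #strided_layout):
///   %1 = tensor.extract_slice %0[0, 0] [2, 2] [1, 1]
///       : tensor<4x4xf32> to tensor<2x2xf32>
/// bufferizes roughly to:
///   %1 = memref.subview %0[0, 0] [2, 2] [1, 1]
///       : memref<4x4xf32> to memref<2x2xf32, #strided_layout>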
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.getResult().getType().cast<RankedTensorType>();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
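///
/// Illustrative example (names are hypothetical):
///   %e = tensor.extract %t[%i] : tensor<?xf32>
/// bufferizes roughly to:
///   %e = memref.load %m[%i] : memref<?xf32>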
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
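// For example, for a 2x3 shape the stores are emitted in row-major order:
// (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), consuming one element per store.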
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
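///
/// Illustrative example (element values are hypothetical):
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
/// bufferizes roughly to an allocation followed by one store per element:
///   %m = memref.alloc() : memref<2xf32>
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>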
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(),
        analysisState.isTensorYielded(fromElementsOp.getResult()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
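///
/// Illustrative example (shape and body are hypothetical):
///   %t = tensor.generate %sz {
///   ^bb0(%i: index):
///     %v = ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
/// bufferizes roughly to an allocation plus an scf.parallel loop whose body
/// stores the yielded value into the buffer at the loop induction variables.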
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(),
        analysisState.isTensorYielded(generateOp.getResult()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
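///
/// Illustrative example (names are hypothetical):
///   %1 = tensor.insert %f into %t[%i] : tensor<?xf32>
/// bufferizes roughly to a store into the (possibly copied) destination
/// buffer:
///   memref.store %f, %m[%i] : memref<?xf32>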
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
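///
/// Illustrative example (offsets/sizes/strides are hypothetical; the subview
/// layout is shown as a placeholder #strided_layout, and the actual copy op
/// depends on BufferizationOptions::createMemCpy):
///   %1 = tensor.insert_slice %src into %dst[0, 0] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<4x4xf32>
/// bufferizes roughly to a subview of the destination buffer plus a copy:
///   %sv = memref.subview %dst[0, 0] [2, 2] [1, 1]
///       : memref<4x4xf32> to memref<2x2xf32, #strided_layout>
///   memref.copy %src, %sv
///       : memref<2x2xf32> to memref<2x2xf32, #strided_layout>
/// If the source comes from a matching tensor.extract_slice, the copy can
/// later fold away.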
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
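///
/// Illustrative example (names are hypothetical):
///   %r = tensor.rank %t : tensor<*xf32>
/// bufferizes roughly to:
///   %r = memref.rank %m : memref<*xf32>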
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
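///
/// Illustrative example (shapes are hypothetical):
///   %1 = tensor.reshape %0(%shape)
///       : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
/// bufferizes roughly to:
///   %1 = memref.reshape %0(%shape)
///       : (memref<?xf32>, memref<1xindex>) -> memref<?xf32>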
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*shape*/)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(resultTensorType, options);
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}