1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Operation.h"
18 
19 using namespace mlir;
20 using namespace mlir::bufferization;
21 using namespace mlir::tensor;
22 
23 namespace mlir {
24 namespace tensor {
25 namespace {
26 
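/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Illustrative sketch (example shapes only; %m stands for the source's
/// buffer; the actual result memref type, layout and memory space are
/// computed in `bufferize` below):
///
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
///
/// bufferizes to a cast of the source buffer:
///
///   %1 = memref.cast %m : memref<4xf32> to memref<?xf32>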
27 struct CastOpInterface
28     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
29                                                     tensor::CastOp> {
30   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
31                               const AnalysisState &state) const {
32     return false;
33   }
34 
35   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
36                                const AnalysisState &state) const {
37     return false;
38   }
39 
40   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
41                                             const AnalysisState &state) const {
42     return {op->getResult(0)};
43   }
44 
45   BufferRelation bufferRelation(Operation *op, OpResult opResult,
46                                 const AnalysisState &state) const {
47     return BufferRelation::Equivalent;
48   }
49 
50   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51                           const BufferizationOptions &options) const {
52     auto castOp = cast<tensor::CastOp>(op);
53 
    // The result buffer is the source's buffer (the cast aliases its source),
    // but it still has the old (pre-cast) type.
55     Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
56     auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
57     Attribute memorySpace = sourceMemRefType.getMemorySpace();
58     TensorType resultTensorType =
59         castOp.getResult().getType().cast<TensorType>();
60     MemRefLayoutAttrInterface layout;
61 
62     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
63       if (resultTensorType.isa<RankedTensorType>())
64         layout = rankedMemRefType.getLayout();
65 
66     // Compute the new memref type.
67     Type resultMemRefType =
68         getMemRefType(resultTensorType, options, layout, memorySpace);
69 
70     // Replace the op with a memref.cast.
71     assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
72                                              resultMemRefType) &&
73            "CallOp::bufferize: cast incompatible");
74     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
75                                                  resultBuffer);
76 
77     return success();
78   }
79 };
80 
81 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
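///
/// Illustrative sketch (example shapes and reassociation only; %m stands for
/// the source's buffer):
///
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
///
/// bufferizes to a collapse of the source buffer:
///
///   %1 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>
///
/// If the source layout makes the dimensions non-collapsible, the source is
/// first copied into a fresh identity-layout buffer (see `bufferize` below).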
82 struct CollapseShapeOpInterface
83     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
84                                                     tensor::CollapseShapeOp> {
85   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
86                               const AnalysisState &state) const {
87     return false;
88   }
89 
90   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
91                                const AnalysisState &state) const {
92     return false;
93   }
94 
95   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
96                                             const AnalysisState &state) const {
97     if (&opOperand == &op->getOpOperand(0) /*src*/)
98       return {op->getOpResult(0)};
99     return {};
100   }
101 
102   BufferRelation bufferRelation(Operation *op, OpResult opResult,
103                                 const AnalysisState &state) const {
104     return BufferRelation::Equivalent;
105   }
106 
107   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
108                           const BufferizationOptions &options) const {
109     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
110     RankedTensorType tensorResultType = collapseShapeOp.getResultType();
111     Value buffer = getBuffer(rewriter, collapseShapeOp.getSrc(), options);
112     auto bufferType = buffer.getType().cast<MemRefType>();
113 
114     if (tensorResultType.getRank() == 0) {
115       // 0-d collapses must go through a different op builder.
116       MemRefType resultType;
117 
118       if (bufferType.getLayout().isIdentity()) {
119         // Standard layout: result type has no offset.
120         MemRefLayoutAttrInterface layout;
121         resultType = MemRefType::get({}, tensorResultType.getElementType(),
122                                      layout, bufferType.getMemorySpace());
123       } else {
124         // Source memref has a layout map: result type has the same offset as
125         // the source type.
126         SmallVector<int64_t> strides;
127         int64_t offset;
128         if (failed(getStridesAndOffset(bufferType, strides, offset)))
129           return failure();
130         AffineMap resultLayout =
131             makeStridedLinearLayoutMap({}, offset, op->getContext());
132         resultType =
133             MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
134                             bufferType.getMemorySpaceAsInt());
135       }
136 
137       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
138           rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
139       return success();
140     }
141 
142     // If the dims are not collapsible (due to an incompatible source layout
143     // map), force an out-of-place bufferization, i.e., a buffer copy. This
144     // newly allocated buffer will have no layout map and thus be collapsible.
145     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
146         bufferType, collapseShapeOp.getReassociationIndices());
147     if (!canBeCollapsed) {
148       // TODO: Create alloc_tensor ops during TensorCopyInsertion.
149       AnalysisState analysisState(options);
150       Value tensorAlloc = allocateTensorForShapedValue(
151           rewriter, op->getLoc(), collapseShapeOp.getSrc(),
152           analysisState.isTensorYielded(collapseShapeOp.getResult()));
153       auto memrefType =
154           MemRefType::get(collapseShapeOp.getSrcType().getShape(),
155                           collapseShapeOp.getSrcType().getElementType(),
156                           AffineMap(), bufferType.getMemorySpaceAsInt());
157       buffer = rewriter.create<bufferization::ToMemrefOp>(
158           op->getLoc(), memrefType, tensorAlloc);
159     }
160 
161     // Result type is inferred by the builder.
162     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
163         rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
164     return success();
165   }
166 };
167 
168 /// Bufferization of tensor.dim. Replace with memref.dim.
169 struct DimOpInterface
170     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
171                                                     tensor::DimOp> {
172   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
173                               const AnalysisState &state) const {
174     return true;
175   }
176 
177   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
178                                const AnalysisState &state) const {
179     return false;
180   }
181 
182   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
183                                             const AnalysisState &state) const {
184     return {};
185   }
186 
187   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
188                           const BufferizationOptions &options) const {
189     auto dimOp = cast<tensor::DimOp>(op);
190     auto v = getBuffer(rewriter, dimOp.getSource(), options);
191     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v,
192                                                 dimOp.getIndex());
193     return success();
194   }
195 };
196 
197 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
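///
/// Illustrative sketch (example shapes and reassociation only; %m stands for
/// the source's buffer):
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
///
/// bufferizes to an expand of the source buffer:
///
///   %1 = memref.expand_shape %m [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>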
198 struct ExpandShapeOpInterface
199     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
200                                                     tensor::ExpandShapeOp> {
201   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
202                               const AnalysisState &state) const {
203     return false;
204   }
205 
206   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
207                                const AnalysisState &state) const {
208     return false;
209   }
210 
211   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
212                                             const AnalysisState &state) const {
213     if (&opOperand == &op->getOpOperand(0) /*src*/)
214       return {op->getOpResult(0)};
215     return {};
216   }
217 
218   BufferRelation bufferRelation(Operation *op, OpResult opResult,
219                                 const AnalysisState &state) const {
220     return BufferRelation::Equivalent;
221   }
222 
223   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
224                           const BufferizationOptions &options) const {
225     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
226     auto tensorResultType = expandShapeOp.getResultType();
227     auto buffer = getBuffer(rewriter, expandShapeOp.getSrc(), options);
228 
229     // Memref result type is inferred by the builder based on reassociation
230     // indices and result shape.
231     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
232         rewriter, op, tensorResultType.getShape(), buffer,
233         expandShapeOp.getReassociationIndices());
234     return success();
235   }
236 };
237 
238 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
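///
/// Illustrative sketch (example offsets/sizes/strides only; %m stands for the
/// source's buffer and #layout for the strided layout inferred by the subview
/// builder):
///
///   %1 = tensor.extract_slice %t[0, %o][4, %s][1, 1]
///       : tensor<?x?xf32> to tensor<4x?xf32>
///
/// bufferizes to a subview of the source buffer:
///
///   %1 = memref.subview %m[0, %o][4, %s][1, 1]
///       : memref<?x?xf32> to memref<4x?xf32, #layout>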
239 struct ExtractSliceOpInterface
240     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
241                                                     tensor::ExtractSliceOp> {
242   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
243                               const AnalysisState &state) const {
244     return false;
245   }
246 
247   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
248                                const AnalysisState &state) const {
249     return false;
250   }
251 
252   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
253                                             const AnalysisState &state) const {
254     if (&opOperand == &op->getOpOperand(0) /*source*/)
255       return {op->getOpResult(0)};
256     return {};
257   }
258 
259   BufferRelation bufferRelation(Operation *op, OpResult opResult,
260                                 const AnalysisState &state) const {
261     return BufferRelation::None;
262   }
263 
264   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
265                           const BufferizationOptions &options) const {
266     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
267     Location loc = extractSliceOp.getLoc();
268 
    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy here; any required copy has already been inserted as part of
    // tensor copy insertion.
271     auto srcMemref = getBuffer(rewriter, extractSliceOp.getSource(), options);
272     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
273     auto dstTensorType =
274         extractSliceOp.getResult().getType().cast<RankedTensorType>();
275 
276     // Expand offsets, sizes and strides to the full rank to handle the
277     // rank-reducing case.
278     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
279     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
280     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
281     OffsetSizeAndStrideOpInterface::expandToRank(
282         srcMemref, mixedOffsets, mixedSizes, mixedStrides,
283         [&](Value target, int64_t dim) -> OpFoldResult {
284           auto shapedType = target.getType().cast<ShapedType>();
285           if (shapedType.isDynamicDim(dim))
286             return rewriter.create<memref::DimOp>(loc, target, dim).result();
287           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
288         });
289     // Bufferize to subview.
290     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
291                                  dstTensorType.getRank(), srcMemrefType,
292                                  mixedOffsets, mixedSizes, mixedStrides)
293                                  .cast<MemRefType>();
294     Value subView = rewriter.create<memref::SubViewOp>(
295         loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
296         mixedStrides);
297 
298     replaceOpWithBufferizedValues(rewriter, op, subView);
299     return success();
300   }
301 };
302 
303 /// Bufferization of tensor.extract. Replace with memref.load.
304 struct ExtractOpInterface
305     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
306                                                     tensor::ExtractOp> {
307   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
308                               const AnalysisState &state) const {
309     return true;
310   }
311 
312   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
313                                const AnalysisState &state) const {
314     return false;
315   }
316 
317   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
318                                             const AnalysisState &state) const {
319     return {};
320   }
321 
322   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
323                           const BufferizationOptions &options) const {
324     auto extractOp = cast<tensor::ExtractOp>(op);
325     Value srcMemref = getBuffer(rewriter, extractOp.getTensor(), options);
326     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
327                                                  extractOp.getIndices());
328     return success();
329   }
330 };
331 
// Recursively traverses the indices of the output buffer in row-major order
// while iterating over op.elements(), creating one memref.store per element.
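//
// For example (hypothetical shape [2, 3]), stores are created for the indices
// (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), in that order, consuming one
// element of op.elements() per store.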
334 static void createStores(RewriterBase &rewriter, Location loc, int dim,
335                          Value buffer, ArrayRef<int64_t> shape,
336                          ArrayRef<Value> constants,
337                          OperandRange::iterator &elementIt,
338                          SmallVectorImpl<Value> &indices) {
339   if (dim == static_cast<int>(shape.size()) - 1) {
340     for (int i = 0; i < shape.back(); ++i) {
341       indices.back() = constants[i];
342       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
343       ++elementIt;
344     }
345     return;
346   }
347   for (int i = 0; i < shape[dim]; ++i) {
348     indices[dim] = constants[i];
349     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
350                  indices);
351   }
352 }
353 
354 /// Bufferization of tensor.from_elements.
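///
/// Illustrative sketch (example shape only; %alloc stands for the allocated
/// result buffer and %c0/%c1 for index constants):
///
///   %0 = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
///
/// bufferizes to a new allocation plus one store per element, roughly:
///
///   memref.store %a, %alloc[%c0, %c0] : memref<2x2xf32>
///   memref.store %b, %alloc[%c0, %c1] : memref<2x2xf32>
///   memref.store %c, %alloc[%c1, %c0] : memref<2x2xf32>
///   memref.store %d, %alloc[%c1, %c1] : memref<2x2xf32>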
355 struct FromElementsOpInterface
356     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
357                                                     tensor::FromElementsOp> {
358   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
359                           const BufferizationOptions &options) const {
360     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
361 
362     // Allocate a buffer for the result.
363     Location loc = op->getLoc();
364     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
365     auto shape = tensorType.getShape();
366     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
367     AnalysisState analysisState(options);
368     Value tensorAlloc = allocateTensorForShapedValue(
369         rewriter, loc, fromElementsOp.getResult(),
370         analysisState.isTensorYielded(fromElementsOp.getResult()),
371         /*copy=*/false);
372     auto memrefType =
373         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
374     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
375         op->getLoc(), memrefType, tensorAlloc);
376 
377     // Case: tensor<0xelem_type>.
378     if (fromElementsOp.getElements().empty()) {
379       replaceOpWithBufferizedValues(rewriter, op, buffer);
380       return success();
381     }
382 
383     // Case: tensor<elem_type>.
384     if (shape.empty()) {
385       rewriter.create<memref::StoreOp>(
386           loc, fromElementsOp.getElements().front(), buffer);
387       replaceOpWithBufferizedValues(rewriter, op, buffer);
388       return success();
389     }
390 
391     // Create constants for the range of possible indices [0, max{shape_i}).
392     auto maxDim = *std::max_element(shape.begin(), shape.end());
393     SmallVector<Value, 2> constants;
394     constants.reserve(maxDim);
395     for (int i = 0; i < maxDim; ++i)
396       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
397 
398     // Traverse all `elements` and create `memref.store` ops.
399     auto elementIt = fromElementsOp.getElements().begin();
400     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
401     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
402                  indices);
403 
404     replaceOpWithBufferizedValues(rewriter, op, buffer);
405     return success();
406   }
407 };
408 
409 /// Bufferization of tensor.generate.
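///
/// Illustrative sketch (example type only):
///
///   %0 = tensor.generate %d {
///   ^bb0(%i: index, %j: index):
///     ...
///     tensor.yield %elem : f32
///   } : tensor<?x3xf32>
///
/// bufferizes to a new allocation and an scf.parallel loop over the buffer's
/// index space; the generator body is inlined into the loop and the
/// tensor.yield becomes a memref.store into the allocation.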
410 struct GenerateOpInterface
411     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
412                                                     tensor::GenerateOp> {
413   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
414                           const BufferizationOptions &options) const {
415     auto generateOp = cast<tensor::GenerateOp>(op);
416     auto tensorType = generateOp.getType().cast<RankedTensorType>();
417     // Allocate memory.
418     Location loc = op->getLoc();
419     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
420     AnalysisState analysisState(options);
421     Value tensorAlloc = allocateTensorForShapedValue(
422         rewriter, loc, generateOp.getResult(),
423         analysisState.isTensorYielded(generateOp.getResult()),
424         /*copy=*/false);
425     auto memrefType =
426         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
427     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
428         op->getLoc(), memrefType, tensorAlloc);
429 
430     // Collect loop bounds.
431     int64_t rank = memrefType.getRank();
432     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
433     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
434     SmallVector<Value, 4> lowerBounds(rank, zero);
435     SmallVector<Value, 4> steps(rank, one);
436     SmallVector<Value, 4> upperBounds;
437     int nextDynamicIndex = 0;
438     for (int i = 0; i < rank; i++) {
439       Value upperBound =
440           memrefType.isDynamicDim(i)
441               ? generateOp.getDynamicExtents()[nextDynamicIndex++]
442               : rewriter.create<arith::ConstantIndexOp>(
443                     loc, memrefType.getDimSize(i));
444       upperBounds.push_back(upperBound);
445     }
446 
447     // Generate tensor elements with a parallel loop that stores into
448     // each element of the resulting memref. We use mergeBlockBefore to "move"
449     // this op's body into the scf.parallel's body.
450     auto parallel =
451         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
452     Block *parallelBody = parallel.getBody();
453     rewriter.mergeBlockBefore(&generateOp.getBody().front(),
454                               parallelBody->getTerminator(),
455                               parallelBody->getArguments());
456     // Replace the inlined yield op with a store op. The scf.parallel's builder
457     // already populated an scf.yield at the end, so we don't need to worry
458     // about creating that.
459     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
460     rewriter.setInsertionPointAfter(elementYield);
461     rewriter.replaceOpWithNewOp<memref::StoreOp>(
462         elementYield, elementYield->getOperands()[0], buffer,
463         parallelBody->getArguments());
464 
465     replaceOpWithBufferizedValues(rewriter, op, buffer);
466     return success();
467   }
468 };
469 
470 /// Bufferization of tensor.insert. Replace with memref.store.
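///
/// Illustrative sketch (%t_buf stands for the destination's buffer):
///
///   %r = tensor.insert %f into %t[%i] : tensor<5xf32>
///
/// bufferizes to a store into the destination buffer, which then stands in
/// for the result tensor:
///
///   memref.store %f, %t_buf[%i] : memref<5xf32>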
471 struct InsertOpInterface
472     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
473                                                     tensor::InsertOp> {
474   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
475                               const AnalysisState &state) const {
476     return true;
477   }
478 
479   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
480                                const AnalysisState &state) const {
481     return true;
482   }
483 
484   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
485                                             const AnalysisState &state) const {
486     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
487            "expected dest OpOperand");
488     return {op->getOpResult(0)};
489   }
490 
491   SmallVector<OpOperand *>
492   getAliasingOpOperand(Operation *op, OpResult opResult,
493                        const AnalysisState &state) const {
494     return {&op->getOpOperand(1) /*dest*/};
495   }
496 
497   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
498                           const BufferizationOptions &options) const {
499     auto insertOp = cast<tensor::InsertOp>(op);
500     Value destMemref = getBuffer(rewriter, insertOp.getDest(), options);
501     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
502                                      destMemref, insertOp.getIndices());
503     replaceOpWithBufferizedValues(rewriter, op, destMemref);
504     return success();
505   }
506 
507   BufferRelation bufferRelation(Operation *op, OpResult opResult,
508                                 const AnalysisState &state) const {
509     return BufferRelation::Equivalent;
510   }
511 };
512 
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches, i.e., the
/// pair has an equivalent operand / result and the same offsets/sizes/strides
/// specification.
515 ///
516 /// This is one particular type of relationship between ops on tensors that
517 /// reduce to an equivalence on buffers. This should be generalized and
518 /// exposed as interfaces on the proper types.
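///
/// For example (illustrative), the two slice ops below match: the extract's
/// source and the insert's destination are the same tensor %t, and the
/// offsets, sizes and strides are identical:
///
///   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   %2 = tensor.insert_slice %1 into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>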
519 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
520                                          ExtractSliceOp st, InsertSliceOp sti) {
521   if (!st || !sti)
522     return false;
  if (!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
525     return false;
526   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
527     return false;
528   return true;
529 }
530 
/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
533 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
534                                       InsertSliceOp insertOp) {
535   auto condition = [&](Value val) {
536     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
537       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
538         return true;
539     return false;
540   };
541 
542   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
543                       condition);
544 }
545 
546 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
547 /// certain circumstances, this op can also be a no-op.
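///
/// Illustrative sketch (example offsets/sizes/strides only; %src_buf and
/// %dst_buf stand for the source and destination buffers, #layout for the
/// inferred subview layout; the exact copy op depends on
/// `options.createMemCpy`):
///
///   %r = tensor.insert_slice %src into %dst[%o][%s][1]
///       : tensor<?xf32> into tensor<?xf32>
///
/// bufferizes to a subview of the destination buffer plus a copy of the
/// source buffer into that subview:
///
///   %sv = memref.subview %dst_buf[%o][%s][1]
///       : memref<?xf32> to memref<?xf32, #layout>
///   memref.copy %src_buf, %sv : memref<?xf32> to memref<?xf32, #layout>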
548 struct InsertSliceOpInterface
549     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
550                                                     tensor::InsertSliceOp> {
551   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
552                               const AnalysisState &state) const {
553     return true;
554   }
555 
556   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
557                                const AnalysisState &state) const {
558     return &opOperand == &op->getOpOperand(1) /*dest*/;
559   }
560 
561   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
562                                             const AnalysisState &state) const {
563     if (&opOperand == &op->getOpOperand(1) /*dest*/)
564       return {op->getResult(0)};
565     return {};
566   }
567 
568   BufferRelation bufferRelation(Operation *op, OpResult opResult,
569                                 const AnalysisState &state) const {
570     return BufferRelation::Equivalent;
571   }
572 
573   bool isNotConflicting(Operation *op, OpOperand *uRead,
574                         OpOperand *uConflictingWrite,
575                         const AnalysisState &state) const {
576     Operation *readingOp = uRead->getOwner();
577     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
578 
579     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
580     // uRead is an InsertSliceOp...
581     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
582       // As an example, consider the following IR.
583       //
584       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
585       // %1 = linalg.fill %cst, %0 {inplace= [true] }
586       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
587       //     {inplace= [true] }
588 
589       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
590       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
591           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
592                                     insertSliceOp))
593         // Case 1: The main insight is that InsertSliceOp reads only part of
594         // the destination tensor. The overwritten area is not read. If
595         // uConflictingWrite writes into exactly the memory location that is
596         // being read by uRead, this is not a conflict.
597         //
598         // In the above example:
599         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
600         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
601         //
602         // The read of %t does not conflict with the write of the FillOp
603         // (same aliases!) because the area that the FillOp operates on is
604         // exactly the one that is *not* read via %t.
605         return true;
606 
607       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
608           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
609           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
610         // Case 2: The read of the source tensor and the write to the dest
611         // tensor via an InsertSliceOp is not a conflict if the read is
612         // reading exactly that part of an equivalent tensor that the
613         // InsertSliceOp is writing.
614         //
615         // In the above example:
616         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
617         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
618         return true;
619     }
620 
621     // If uConflictingWrite is an InsertSliceOp...
622     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
623       // As an example, consider the following IR.
624       //
625       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
626       // %1 = linalg.fill %cst, %0 {inplace= [true] }
627       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
628       //     {inplace= [true] }
629       // %3 = vector.transfer_read %1, %cst
630       //
631       // In the above example:
632       // uRead             = OpOperand 0 (%1) of vector.transfer_read
633       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
634       // lastWrite         = %1
635       //
636       // This is not a conflict because the InsertSliceOp overwrites the
637       // memory segment of %1 with the exact same data. (Effectively, there
638       // is no memory write here.)
639       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
640           state.areEquivalentBufferizedValues(uRead->get(),
641                                               insertSliceOp.getSource()) &&
642           hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
643                                     insertSliceOp))
644         return true;
645 
646     return false;
647   }
648 
649   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
650                           const BufferizationOptions &options) const {
651     // insert_slice ops arise from tiling and bufferizing them out-of-place is
652     // generally a deal breaker. When used with loops, this ends up cloning the
653     // whole tensor on every single iteration and is a symptom of a
654     // catastrophically bad scheduling decision.
655     // TODO: be very loud about it or even consider failing the pass.
656     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
657     Location loc = insertSliceOp.getLoc();
658     Value dstMemref = getBuffer(rewriter, insertSliceOp.getDest(), options);
659 
660     // Expand offsets, sizes and strides to the full rank to handle the
661     // rank-reducing case.
662     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
663     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
664     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
665     OffsetSizeAndStrideOpInterface::expandToRank(
666         dstMemref, mixedOffsets, mixedSizes, mixedStrides,
667         [&](Value target, int64_t dim) -> OpFoldResult {
668           auto shapedType = target.getType().cast<ShapedType>();
669           if (shapedType.isDynamicDim(dim))
670             return rewriter.create<memref::DimOp>(loc, target, dim).result();
671           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
672         });
673     // Take a subview of the dst.
674     auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
675     auto subviewMemRefType =
676         memref::SubViewOp::inferRankReducedResultType(
677             insertSliceOp.getSourceType().getRank(), dstMemrefType,
678             mixedOffsets, mixedSizes, mixedStrides)
679             .cast<MemRefType>();
680     Value subView = rewriter.create<memref::SubViewOp>(
681         loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes,
682         mixedStrides);
683 
684     // Copy tensor. If this tensor.insert_slice has a matching
685     // tensor.extract_slice, the copy operation will eventually fold away.
686     auto srcMemref = getBuffer(rewriter, insertSliceOp.getSource(), options);
687     if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
688       return failure();
689 
690     replaceOpWithBufferizedValues(rewriter, op, dstMemref);
691     return success();
692   }
693 };
694 
695 /// Bufferization of tensor.rank. Replace with memref.rank.
696 struct RankOpInterface
697     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
698                                                     tensor::RankOp> {
699   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
700                               const AnalysisState &state) const {
701     return true;
702   }
703 
704   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
705                                const AnalysisState &state) const {
706     return false;
707   }
708 
709   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
710                                             const AnalysisState &state) const {
711     return {};
712   }
713 
714   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
715                           const BufferizationOptions &options) const {
716     auto rankOp = cast<tensor::RankOp>(op);
717     auto v = getBuffer(rewriter, rankOp.getTensor(), options);
718     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
719                                                  v);
720     return success();
721   }
722 };
723 
724 /// Bufferization of tensor.reshape. Replace with memref.reshape.
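///
/// Illustrative sketch (example types only; %src_buf and %shape_buf stand for
/// the source and shape buffers):
///
///   %r = tensor.reshape %src(%shape)
///       : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
///
/// bufferizes to:
///
///   %r = memref.reshape %src_buf(%shape_buf)
///       : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>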
725 struct ReshapeOpInterface
726     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
727                                                     tensor::ReshapeOp> {
728   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
729                               const AnalysisState &state) const {
730     if (&opOperand == &op->getOpOperand(1) /* shape */)
731       return true;
732     return false;
733   }
734 
735   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
736                                const AnalysisState &state) const {
737     return false;
738   }
739 
740   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
741                                             const AnalysisState &state) const {
742     return {op->getOpResult(0)};
743   }
744 
745   BufferRelation bufferRelation(Operation *op, OpResult opResult,
746                                 const AnalysisState &state) const {
747     return BufferRelation::Equivalent;
748   }
749 
750   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
751                           const BufferizationOptions &options) const {
752     auto reshapeOp = cast<tensor::ReshapeOp>(op);
753     Value srcBuffer = getBuffer(rewriter, reshapeOp.getSource(), options);
754     Value shapeBuffer = getBuffer(rewriter, reshapeOp.getShape(), options);
755     auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
756     auto resultMemRefType = getMemRefType(resultTensorType, options);
757     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
758         rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
759     return success();
760   }
761 };
762 
763 } // namespace
764 } // namespace tensor
765 } // namespace mlir
766 
767 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
768     DialectRegistry &registry) {
769   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
770     CastOp::attachInterface<CastOpInterface>(*ctx);
771     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
772     DimOp::attachInterface<DimOpInterface>(*ctx);
773     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
774     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
775     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
776     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
777     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
778     InsertOp::attachInterface<InsertOpInterface>(*ctx);
779     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
780     RankOp::attachInterface<RankOpInterface>(*ctx);
781     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
782   });
783 }
784