//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

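/// Bufferization of tensor.cast. Replace with memref.cast. A sketch of the
/// rewrite (layout maps and memory spaces omitted):
///
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
///
/// bufferizes to roughly:
///
///   %1 = memref.cast %0 : memref<?xf32> to memref<4xf32>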
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
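/// A sketch of the rewrite (layout maps omitted):
///
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
///
/// bufferizes to roughly:
///
///   %1 = memref.collapse_shape %0 [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>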
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      auto bufferType = buffer.getType().cast<MemRefType>();
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpace());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
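/// A sketch of the rewrite:
///
///   %1 = tensor.dim %0, %c0 : tensor<?xf32>
///
/// bufferizes to roughly:
///
///   %1 = memref.dim %0, %c0 : memref<?xf32>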
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
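/// A sketch of the rewrite (layout maps omitted):
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
///
/// bufferizes to roughly:
///
///   %1 = memref.expand_shape %0 [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>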
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
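/// A sketch of the in-place case (#map stands for the inferred strided
/// layout of the subview):
///
///   %1 = tensor.extract_slice %0[%o][%s][1]
///       : tensor<?xf32> to tensor<?xf32>
///
/// bufferizes to roughly:
///
///   %1 = memref.subview %0[%o][%s][1] : memref<?xf32> to memref<?xf32, #map>
///
/// If the op does not bufferize in-place, the subview is additionally copied
/// into a new allocation.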
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
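/// A sketch of the rewrite:
///
///   %1 = tensor.extract %0[%i] : tensor<?xf32>
///
/// bufferizes to roughly:
///
///   %1 = memref.load %0[%i] : memref<?xf32>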
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
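// E.g., for shape [2, 3], stores are created at indices (0, 0), (0, 1),
// (0, 2), (1, 0), (1, 1), (1, 2), consuming one element of `elementIt` per
// store.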
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
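/// A sketch of the rewrite, assuming an allocation via memref.alloc (the
/// actual allocation op depends on the bufferization options); %c0 and %c1
/// are index constants:
///
///   %0 = tensor.from_elements %a, %b : tensor<2xf32>
///
/// bufferizes to roughly:
///
///   %0 = memref.alloc() : memref<2xf32>
///   memref.store %a, %0[%c0] : memref<2xf32>
///   memref.store %b, %0[%c1] : memref<2xf32>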
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
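/// A sketch of the rewrite: the body of
///
///   %0 = tensor.generate %sz {
///   ^bb0(%i : index):
///     ...
///     tensor.yield %elem : f32
///   } : tensor<?xf32>
///
/// is moved into the body of an scf.parallel loop that iterates over the
/// index space of a new allocation, with tensor.yield replaced by a
/// memref.store into that allocation.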
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
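/// A sketch of the in-place case:
///
///   %1 = tensor.insert %f into %0[%i] : tensor<?xf32>
///
/// bufferizes to roughly:
///
///   memref.store %f, %0[%i] : memref<?xf32>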
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
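/// E.g., the following pair matches when %t is the same tensor in both ops
/// (or the source and dest are equivalent bufferized values):
///
///   %0 = tensor.extract_slice %t[4][16][1] : tensor<?xf32> to tensor<16xf32>
///   ...
///   %2 = tensor.insert_slice %1 into %t[4][16][1]
///       : tensor<16xf32> into tensor<?xf32>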
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
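/// A sketch of the rewrite (#map stands for the inferred strided layout of
/// the subview; the exact copy op depends on the bufferization options):
///
///   %2 = tensor.insert_slice %1 into %0[%o][%s][1]
///       : tensor<?xf32> into tensor<?xf32>
///
/// bufferizes to roughly:
///
///   %sv = memref.subview %dst[%o][%s][1]
///       : memref<?xf32> to memref<?xf32, #map>
///   ... copy of the buffer of %1 into %sv ...
///
/// where %dst is the buffer of %0.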
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    // When bufferizing out-of-place, `getBuffer` allocates a new buffer.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
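/// A sketch of the rewrite:
///
///   %1 = tensor.rank %0 : tensor<*xf32>
///
/// bufferizes to roughly:
///
///   %1 = memref.rank %0 : memref<*xf32>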
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
  });
}