//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

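/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// For example (illustrative, where %m is the buffer of %0):
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
/// bufferizes to:
///   %1 = memref.cast %m : memref<4xf32> to memref<?xf32>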
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
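///
/// For example (illustrative, where %m is the buffer of %0):
///   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
/// bufferizes to:
///   %1 = memref.collapse_shape %m [[0, 1]] : memref<2x3xf32> into memref<6xf32>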
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      auto bufferType = buffer.getType().cast<MemRefType>();
      // Assume identity layout: No offset.
      assert(bufferType.getLayout().isIdentity() &&
             "non-zero offset for 0-d collapse not supported");
      MemRefLayoutAttrInterface layout;
      auto resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                        layout, bufferType.getMemorySpace());
      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
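///
/// For example (illustrative, where %m is the buffer of %0):
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
/// bufferizes to:
///   %d = memref.dim %m, %c0 : memref<?xf32>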
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
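///
/// For example (illustrative, where %m is the buffer of %0):
///   %1 = tensor.expand_shape %0 [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>
/// bufferizes to:
///   %1 = memref.expand_shape %m [[0, 1]] : memref<6xf32> into memref<2x3xf32>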
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
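///
/// For example (illustrative, where %m is the buffer of %0 and #strided_map
/// stands for the strided layout computed by the subview):
///   %1 = tensor.extract_slice %0[%o][%s][1] : tensor<?xf32> to tensor<?xf32>
/// bufferizes, in the in-place case, to:
///   %1 = memref.subview %m[%o][%s][1] : memref<?xf32> to memref<?xf32, #strided_map>
/// If the analysis decided to bufferize out-of-place, the subview contents are
/// copied into a newly allocated buffer instead.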
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
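///
/// For example (illustrative, where %m is the buffer of %0):
///   %v = tensor.extract %0[%i, %j] : tensor<4x4xf32>
/// bufferizes to:
///   %v = memref.load %m[%i, %j] : memref<4x4xf32>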
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
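// For example (illustrative), a tensor<2x3xf32> is traversed in row-major
// order: stores are emitted at indices [0, 0], [0, 1], [0, 2], [1, 0], [1, 1],
// [1, 2], consuming one element of op.elements() per memref.store.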
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
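///
/// For example (illustrative; the actual allocation goes through the
/// bufferization allocation hooks):
///   %0 = tensor.from_elements %a, %b : tensor<2xf32>
/// bufferizes to an allocation followed by one memref.store per element:
///   %m = memref.alloc() : memref<2xf32>
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>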
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
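///
/// For example (illustrative), a tensor.generate over a dynamically sized 1-d
/// tensor bufferizes to an allocation plus an scf.parallel loop whose body is
/// the inlined generator region, terminated by a memref.store of the yielded
/// value:
///   %m = memref.alloc(%sz) : memref<?xf32>
///   scf.parallel (%i) = (%c0) to (%sz) step (%c1) {
///     ... inlined tensor.generate body ...
///     memref.store %yielded, %m[%i] : memref<?xf32>
///   }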
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
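///
/// For example (illustrative, where %m is the buffer of the destination %0):
///   %1 = tensor.insert %v into %0[%i] : tensor<8xf32>
/// bufferizes to a store into the (possibly copied) destination buffer:
///   memref.store %v, %m[%i] : memref<8xf32>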
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offsets/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
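///
/// For example (illustrative), the following pair matches: both slices use the
/// same offsets/sizes/strides and the extract source %0 is equivalent to the
/// insert destination %0:
///   %1 = tensor.extract_slice %0[%a][%b][1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   %2 = tensor.insert_slice %x into %0[%a][%b][1]
///       : tensor<?xf32> into tensor<?xf32>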
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
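///
/// For example (illustrative, where %dst and %src are the buffers of %0 and %1,
/// #strided_map stands for the strided layout computed by the subview, and the
/// exact copy op depends on the bufferization options):
///   %2 = tensor.insert_slice %1 into %0[%o][%s][1]
///       : tensor<?xf32> into tensor<?xf32>
/// bufferizes to a subview of the destination buffer plus a copy of the source
/// buffer into that subview:
///   %sv = memref.subview %dst[%o][%s][1] : memref<?xf32> to memref<?xf32, #strided_map>
///   memref.copy %src, %sv : memref<?xf32> to memref<?xf32, #strided_map>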
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
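///
/// For example (illustrative, where %m is the buffer of %0):
///   %r = tensor.rank %0 : tensor<*xf32>
/// bufferizes to:
///   %r = memref.rank %m : memref<*xf32>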
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
  });
}