1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 using namespace mlir::tensor;
20 
21 namespace mlir {
22 namespace tensor {
23 namespace {
24 
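/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Example (illustrative only; SSA names and memref layouts are made up):
///   %0 = tensor.cast %t : tensor<?xf32> to tensor<4xf32>
/// bufferizes to:
///   %0 = memref.cast %m : memref<?xf32> to memref<4xf32>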
25 struct CastOpInterface
26     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
27                                                     tensor::CastOp> {
28   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
29                               const AnalysisState &state) const {
30     return false;
31   }
32 
33   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
34                                const AnalysisState &state) const {
35     return false;
36   }
37 
38   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
39                                             const AnalysisState &state) const {
40     return {op->getResult(0)};
41   }
42 
43   BufferRelation bufferRelation(Operation *op, OpResult opResult,
44                                 const AnalysisState &state) const {
45     return BufferRelation::Equivalent;
46   }
47 
48   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
49                           BufferizationState &state) const {
50     auto castOp = cast<tensor::CastOp>(op);
51 
52     // The result buffer still has the old (pre-cast) type.
53     FailureOr<Value> resultBuffer =
54         state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
55     if (failed(resultBuffer))
56       return failure();
57     auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
58     Attribute memorySpace = sourceMemRefType.getMemorySpace();
59     TensorType resultTensorType =
60         castOp.getResult().getType().cast<TensorType>();
61     MemRefLayoutAttrInterface layout;
62 
63     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
64       if (resultTensorType.isa<RankedTensorType>())
65         layout = rankedMemRefType.getLayout();
66 
67     // Compute the new memref type.
68     Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
69                                           layout, memorySpace);
70 
71     // Replace the op with a memref.cast.
72     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
73                                              resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
75     replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
76                                                  *resultBuffer);
77 
78     return success();
79   }
80 };
81 
82 /// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
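///
/// Example (illustrative; assumes an identity-layout source buffer):
///   %0 = tensor.collapse_shape %t [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// bufferizes to:
///   %0 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>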
83 struct CollapseShapeOpInterface
84     : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
85                                                     tensor::CollapseShapeOp> {
86   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
87                               const AnalysisState &state) const {
88     return false;
89   }
90 
91   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
92                                const AnalysisState &state) const {
93     return false;
94   }
95 
96   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
97                                             const AnalysisState &state) const {
98     if (&opOperand == &op->getOpOperand(0) /*src*/)
99       return {op->getOpResult(0)};
100     return {};
101   }
102 
103   BufferRelation bufferRelation(Operation *op, OpResult opResult,
104                                 const AnalysisState &state) const {
105     return BufferRelation::Equivalent;
106   }
107 
108   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
109                           BufferizationState &state) const {
110     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
111     Value buffer =
112         *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
113     Type resultType =
114         getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
115     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
116         rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
117     return success();
118   }
119 };
120 
121 /// Bufferization of tensor.dim. Replace with memref.dim.
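///
/// Example (illustrative):
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// bufferizes to:
///   %d = memref.dim %m, %c0 : memref<?xf32>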
122 struct DimOpInterface
123     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
124                                                     tensor::DimOp> {
125   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
126                               const AnalysisState &state) const {
127     return true;
128   }
129 
130   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
131                                const AnalysisState &state) const {
132     return false;
133   }
134 
135   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
136                                             const AnalysisState &state) const {
137     return {};
138   }
139 
140   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
141                           BufferizationState &state) const {
142     auto dimOp = cast<tensor::DimOp>(op);
143     Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
144     replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
145     return success();
146   }
147 };
148 
149 /// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
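///
/// Example (illustrative; assumes an identity-layout source buffer):
///   %0 = tensor.expand_shape %t [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
/// bufferizes to:
///   %0 = memref.expand_shape %m [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>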
150 struct ExpandShapeOpInterface
151     : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
152                                                     tensor::ExpandShapeOp> {
153   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
154                               const AnalysisState &state) const {
155     return false;
156   }
157 
158   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
159                                const AnalysisState &state) const {
160     return false;
161   }
162 
163   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
164                                             const AnalysisState &state) const {
165     if (&opOperand == &op->getOpOperand(0) /*src*/)
166       return {op->getOpResult(0)};
167     return {};
168   }
169 
170   BufferRelation bufferRelation(Operation *op, OpResult opResult,
171                                 const AnalysisState &state) const {
172     return BufferRelation::Equivalent;
173   }
174 
175   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
176                           BufferizationState &state) const {
177     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
178     Value buffer =
179         *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
180     Type resultType =
181         getMemRefType(expandShapeOp.getResultType(), state.getOptions());
182     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
183         rewriter, op, resultType, buffer, expandShapeOp.reassociation());
184     return success();
185   }
186 };
187 
188 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
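///
/// Example (illustrative; #map stands for the strided layout that
/// memref.subview infers, here offset 4 and stride 1):
///   %0 = tensor.extract_slice %t[4] [8] [1] : tensor<16xf32> to tensor<8xf32>
/// bufferizes, when the source buffer can be used in place, to:
///   %0 = memref.subview %m[4] [8] [1]
///       : memref<16xf32> to memref<8xf32, #map>
/// Otherwise a new buffer is allocated and the subview is copied into it.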
189 struct ExtractSliceOpInterface
190     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
191                                                     tensor::ExtractSliceOp> {
192   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
193                               const AnalysisState &state) const {
194     return false;
195   }
196 
197   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
198                                const AnalysisState &state) const {
199     return false;
200   }
201 
202   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
203                                             const AnalysisState &state) const {
204     if (&opOperand == &op->getOpOperand(0) /*source*/)
205       return {op->getOpResult(0)};
206     return {};
207   }
208 
209   BufferRelation bufferRelation(Operation *op, OpResult opResult,
210                                 const AnalysisState &state) const {
211     return BufferRelation::None;
212   }
213 
214   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
215                           BufferizationState &state) const {
216     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
217     Location loc = extractSliceOp.getLoc();
218     Value srcMemref =
219         *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
220                          /*forceInPlace=*/true);
221     auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
222     auto dstTensorType =
223         extractSliceOp.result().getType().cast<RankedTensorType>();
224 
225     // If not inplaceable, alloc.
226     bool inplace =
227         state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
228     Value alloc;
229     if (!inplace) {
230       FailureOr<Value> allocOrFailure =
231           state.createAlloc(rewriter, loc, extractSliceOp.result());
232       if (failed(allocOrFailure))
233         return failure();
234       alloc = *allocOrFailure;
235     }
236 
237     // Expand offsets, sizes and strides to the full rank to handle the
238     // rank-reducing case.
239     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
240     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
241     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
242     OffsetSizeAndStrideOpInterface::expandToRank(
243         srcMemref, mixedOffsets, mixedSizes, mixedStrides,
244         [&](Value target, int64_t dim) -> OpFoldResult {
245           auto shapedType = target.getType().cast<ShapedType>();
246           if (shapedType.isDynamicDim(dim))
247             return rewriter.create<memref::DimOp>(loc, target, dim).result();
248           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
249         });
250     // Bufferize to subview.
251     auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
252                                  dstTensorType.getRank(), srcMemrefType,
253                                  mixedOffsets, mixedSizes, mixedStrides)
254                                  .cast<MemRefType>();
255     Value subView = rewriter.create<memref::SubViewOp>(
256         loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
257         mixedStrides);
258 
259     // If not inplaceable, copy.
260     if (!inplace) {
261       // Do not copy if the copied data is never read.
262       if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
263         if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
264                                 alloc, state.getOptions())))
265           return failure();
266       subView = alloc;
267     }
268 
269     replaceOpWithBufferizedValues(rewriter, op, subView);
270     return success();
271   }
272 };
273 
274 /// Bufferization of tensor.extract. Replace with memref.load.
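///
/// Example (illustrative):
///   %e = tensor.extract %t[%i] : tensor<?xf32>
/// bufferizes to:
///   %e = memref.load %m[%i] : memref<?xf32>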
275 struct ExtractOpInterface
276     : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
277                                                     tensor::ExtractOp> {
278   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
279                               const AnalysisState &state) const {
280     return true;
281   }
282 
283   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
284                                const AnalysisState &state) const {
285     return false;
286   }
287 
288   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
289                                             const AnalysisState &state) const {
290     return {};
291   }
292 
293   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
294                           BufferizationState &state) const {
295     auto extractOp = cast<tensor::ExtractOp>(op);
296     Value srcMemref =
297         *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
298     replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
299                                                  extractOp.indices());
300     return success();
301   }
302 };
303 
// Recursively create memref.store ops for all elements, enumerating the
// indices of the output buffer in row-major order while iterating over
// op.elements().
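// E.g., for shape [2, 3], stores are created for the indices (0, 0), (0, 1),
// (0, 2), (1, 0), (1, 1), (1, 2), matching the order of op.elements().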
306 static void createStores(RewriterBase &rewriter, Location loc, int dim,
307                          Value buffer, ArrayRef<int64_t> shape,
308                          ArrayRef<Value> constants,
309                          OperandRange::iterator &elementIt,
310                          SmallVectorImpl<Value> &indices) {
311   if (dim == static_cast<int>(shape.size()) - 1) {
312     for (int i = 0; i < shape.back(); ++i) {
313       indices.back() = constants[i];
314       rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
315       ++elementIt;
316     }
317     return;
318   }
319   for (int i = 0; i < shape[dim]; ++i) {
320     indices[dim] = constants[i];
321     createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
322                  indices);
323   }
324 }
325 
326 /// Bufferization of tensor.from_elements.
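///
/// Example (illustrative; the alloc shown stands for whatever allocation the
/// bufferization options create, and %c0/%c1 are index constants):
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
/// bufferizes roughly to:
///   %m = memref.alloc() : memref<2xf32>
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>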
327 struct FromElementsOpInterface
328     : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
329                                                     tensor::FromElementsOp> {
330   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
331                           BufferizationState &state) const {
332     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
333 
334     // Allocate a buffer for the result.
335     Location loc = op->getLoc();
336     auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
337     auto shape = tensorType.getShape();
338     FailureOr<Value> maybeBuffer =
339         state.createAlloc(rewriter, loc, fromElementsOp.result());
340     if (failed(maybeBuffer))
341       return failure();
342     Value buffer = *maybeBuffer;
343 
344     // Case: tensor<0xelem_type>.
345     if (fromElementsOp.elements().empty()) {
346       replaceOpWithBufferizedValues(rewriter, op, buffer);
347       return success();
348     }
349 
350     // Case: tensor<elem_type>.
351     if (shape.empty()) {
352       rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
353                                        buffer);
354       replaceOpWithBufferizedValues(rewriter, op, buffer);
355       return success();
356     }
357 
358     // Create constants for the range of possible indices [0, max{shape_i}).
359     auto maxDim = *std::max_element(shape.begin(), shape.end());
360     SmallVector<Value, 2> constants;
361     constants.reserve(maxDim);
362     for (int i = 0; i < maxDim; ++i)
363       constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
364 
365     // Traverse all `elements` and create `memref.store` ops.
366     auto elementIt = fromElementsOp.elements().begin();
367     SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
368     createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
369                  indices);
370 
371     replaceOpWithBufferizedValues(rewriter, op, buffer);
372     return success();
373   }
374 };
375 
376 /// Bufferization of tensor.generate.
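///
/// Example (schematic):
///   %t = tensor.generate %sz {
///   ^bb0(%i : index):
///     ...
///     tensor.yield %elem : f32
///   } : tensor<?xf32>
/// bufferizes to an allocation of memref<?xf32> plus an scf.parallel loop
/// whose body computes %elem and writes it with memref.store.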
377 struct GenerateOpInterface
378     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
379                                                     tensor::GenerateOp> {
380   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
381                           BufferizationState &state) const {
382     auto generateOp = cast<tensor::GenerateOp>(op);
383 
384     // Allocate memory.
385     Location loc = op->getLoc();
386     MemRefType memrefType =
387         getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
388     FailureOr<Value> maybeResult =
389         state.createAlloc(rewriter, loc, generateOp.result());
390     if (failed(maybeResult))
391       return failure();
392     Value result = *maybeResult;
393 
394     // Collect loop bounds.
395     int64_t rank = memrefType.getRank();
396     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
397     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
398     SmallVector<Value, 4> lowerBounds(rank, zero);
399     SmallVector<Value, 4> steps(rank, one);
400     SmallVector<Value, 4> upperBounds;
401     int nextDynamicIndex = 0;
402     for (int i = 0; i < rank; i++) {
403       Value upperBound = memrefType.isDynamicDim(i)
404                              ? generateOp.dynamicExtents()[nextDynamicIndex++]
405                              : rewriter.create<arith::ConstantIndexOp>(
406                                    loc, memrefType.getDimSize(i));
407       upperBounds.push_back(upperBound);
408     }
409 
410     // Generate tensor elements with a parallel loop that stores into
411     // each element of the resulting memref. We use mergeBlockBefore to "move"
412     // this op's body into the scf.parallel's body.
413     auto parallel =
414         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
415     Block *parallelBody = parallel.getBody();
416     rewriter.mergeBlockBefore(generateOp.getBody(),
417                               parallelBody->getTerminator(),
418                               parallelBody->getArguments());
419     // Replace the inlined yield op with a store op. The scf.parallel's builder
420     // already populated an scf.yield at the end, so we don't need to worry
421     // about creating that.
422     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
423     rewriter.setInsertionPointAfter(elementYield);
424     rewriter.replaceOpWithNewOp<memref::StoreOp>(
425         elementYield, elementYield->getOperands()[0], result,
426         parallelBody->getArguments());
427 
428     replaceOpWithBufferizedValues(rewriter, op, result);
429     return success();
430   }
431 };
432 
433 /// Bufferization of tensor.insert. Replace with memref.store.
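///
/// Example (illustrative):
///   %0 = tensor.insert %v into %t[%i] : tensor<?xf32>
/// bufferizes, when the destination buffer is used in place, to:
///   memref.store %v, %m[%i] : memref<?xf32>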
434 struct InsertOpInterface
435     : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
436                                                     tensor::InsertOp> {
437   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
438                               const AnalysisState &state) const {
439     return true;
440   }
441 
442   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
443                                const AnalysisState &state) const {
444     return true;
445   }
446 
447   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
448                                             const AnalysisState &state) const {
449     assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
450            "expected dest OpOperand");
451     return {op->getOpResult(0)};
452   }
453 
454   SmallVector<OpOperand *>
455   getAliasingOpOperand(Operation *op, OpResult opResult,
456                        const AnalysisState &state) const {
457     return {&op->getOpOperand(1) /*dest*/};
458   }
459 
460   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
461                           BufferizationState &state) const {
462     auto insertOp = cast<tensor::InsertOp>(op);
463     FailureOr<Value> destMemref =
464         state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
465     if (failed(destMemref))
466       return failure();
467     rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
468                                      *destMemref, insertOp.indices());
469     replaceOpWithBufferizedValues(rewriter, op, *destMemref);
470     return success();
471   }
472 
473   BufferRelation bufferRelation(Operation *op, OpResult opResult,
474                                 const AnalysisState &state) const {
475     return BufferRelation::Equivalent;
476   }
477 };
478 
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches, i.e., the
/// extract's source and the insert's dest bufferize to equivalent buffers and
/// both ops use the same offsets/sizes/strides.
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. It should be generalized and exposed
/// as an interface on the proper types.
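///
/// Example of a matching pair (illustrative):
///   %0 = tensor.extract_slice %t[%a] [%c] [1]
///       : tensor<?xf32> to tensor<?xf32>
///   ...
///   %1 = tensor.insert_slice %x into %t[%a] [%c] [1]
///       : tensor<?xf32> into tensor<?xf32>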
485 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
486                                          ExtractSliceOp st, InsertSliceOp sti) {
487   if (!st || !sti)
488     return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
492   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
493     return false;
494   return true;
495 }
496 
/// Return true if `value` originates, along every path of its reverse use-def
/// chain, from an ExtractSliceOp that matches the given InsertSliceOp.
499 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
500                                       InsertSliceOp insertOp) {
501   auto condition = [&](Value val) {
502     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
503       if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
504         return true;
505     return false;
506   };
507 
508   return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
509                       condition);
510 }
511 
512 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
513 /// certain circumstances, this op can also be a no-op.
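///
/// Example (illustrative; #map stands for the strided layout inferred by
/// memref.subview, and the copy may lower differently depending on the
/// bufferization options):
///   %0 = tensor.insert_slice %s into %t[4] [8] [1]
///       : tensor<8xf32> into tensor<16xf32>
/// bufferizes, when the destination buffer is used in place, to:
///   %sv = memref.subview %m[4] [8] [1]
///       : memref<16xf32> to memref<8xf32, #map>
///   memref.copy %s_buf, %sv : memref<8xf32> to memref<8xf32, #map>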
514 struct InsertSliceOpInterface
515     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
516                                                     tensor::InsertSliceOp> {
517   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
518                               const AnalysisState &state) const {
519     return true;
520   }
521 
522   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
523                                const AnalysisState &state) const {
524     return &opOperand == &op->getOpOperand(1) /*dest*/;
525   }
526 
527   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
528                                             const AnalysisState &state) const {
529     if (&opOperand == &op->getOpOperand(1) /*dest*/)
530       return {op->getResult(0)};
531     return {};
532   }
533 
534   BufferRelation bufferRelation(Operation *op, OpResult opResult,
535                                 const AnalysisState &state) const {
536     return BufferRelation::Equivalent;
537   }
538 
539   bool isNotConflicting(Operation *op, OpOperand *uRead,
540                         OpOperand *uConflictingWrite,
541                         const AnalysisState &state) const {
542     Operation *readingOp = uRead->getOwner();
543     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
544 
545     // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
546     // uRead is an InsertSliceOp...
547     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
548       // As an example, consider the following IR.
549       //
550       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
551       // %1 = linalg.fill %cst, %0 {inplace= [true] }
552       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
553       //     {inplace= [true] }
554 
555       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
556       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
557           hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
558                                     insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly that overwritten (and hence
        // unread) area, there is no conflict.
563         //
564         // In the above example:
565         // uRead             = OpOperand 1 (%t) of tensor.insert_slice
566         // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
567         //
568         // The read of %t does not conflict with the write of the FillOp
569         // (same aliases!) because the area that the FillOp operates on is
570         // exactly the one that is *not* read via %t.
571         return true;
572 
573       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
574           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
575           hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
576         // Case 2: The read of the source tensor and the write to the dest
577         // tensor via an InsertSliceOp is not a conflict if the read is
578         // reading exactly that part of an equivalent tensor that the
579         // InsertSliceOp is writing.
580         //
581         // In the above example:
582         // uRead             = OpOperand 0 (%1) of tensor.insert_slice
583         // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
584         return true;
585     }
586 
587     // If uConflictingWrite is an InsertSliceOp...
588     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
589       // As an example, consider the following IR.
590       //
591       // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
592       // %1 = linalg.fill %cst, %0 {inplace= [true] }
593       // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
594       //     {inplace= [true] }
595       // %3 = vector.transfer_read %1, %cst
596       //
597       // In the above example:
598       // uRead             = OpOperand 0 (%1) of vector.transfer_read
599       // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
600       // lastWrite         = %1
601       //
602       // This is not a conflict because the InsertSliceOp overwrites the
603       // memory segment of %1 with the exact same data. (Effectively, there
604       // is no memory write here.)
605       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
606           state.areEquivalentBufferizedValues(uRead->get(),
607                                               insertSliceOp.source()) &&
608           hasMatchingExtractSliceOp(state, insertSliceOp.source(),
609                                     insertSliceOp))
610         return true;
611 
612     return false;
613   }
614 
615   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
616                           BufferizationState &state) const {
617     // insert_slice ops arise from tiling and bufferizing them out-of-place is
618     // generally a deal breaker. When used with loops, this ends up cloning the
619     // whole tensor on every single iteration and is a symptom of a
620     // catastrophically bad scheduling decision.
621     // TODO: be very loud about it or even consider failing the pass.
622     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
623     Location loc = insertSliceOp.getLoc();
624 
    // When bufferizing out-of-place, `getBuffer` allocates a new destination
    // buffer.
626     FailureOr<Value> dstMemref =
627         state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
628     if (failed(dstMemref))
629       return failure();
630 
631     // Expand offsets, sizes and strides to the full rank to handle the
632     // rank-reducing case.
633     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
634     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
635     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
636     OffsetSizeAndStrideOpInterface::expandToRank(
637         *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
638         [&](Value target, int64_t dim) -> OpFoldResult {
639           auto shapedType = target.getType().cast<ShapedType>();
640           if (shapedType.isDynamicDim(dim))
641             return rewriter.create<memref::DimOp>(loc, target, dim).result();
642           return rewriter.getIndexAttr(shapedType.getDimSize(dim));
643         });
644     // Take a subview of the dst.
645     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
646     auto subviewMemRefType =
647         memref::SubViewOp::inferRankReducedResultType(
648             insertSliceOp.getSourceType().getRank(), dstMemrefType,
649             mixedOffsets, mixedSizes, mixedStrides)
650             .cast<MemRefType>();
651     Value subView = rewriter.create<memref::SubViewOp>(
652         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
653         mixedStrides);
654 
655     // Copy tensor. If this tensor.insert_slice has a matching
656     // tensor.extract_slice, the copy operation will eventually fold away.
657     Value srcMemref =
658         *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
659     if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
660                             state.getOptions())))
661       return failure();
662 
663     replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
664     return success();
665   }
666 };
667 
668 /// Bufferization of tensor.rank. Replace with memref.rank.
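///
/// Example (illustrative):
///   %r = tensor.rank %t : tensor<*xf32>
/// bufferizes to:
///   %r = memref.rank %m : memref<*xf32>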
669 struct RankOpInterface
670     : public BufferizableOpInterface::ExternalModel<RankOpInterface,
671                                                     tensor::RankOp> {
672   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
673                               const AnalysisState &state) const {
674     return true;
675   }
676 
677   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
678                                const AnalysisState &state) const {
679     return false;
680   }
681 
682   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
683                                             const AnalysisState &state) const {
684     return {};
685   }
686 
687   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
688                           BufferizationState &state) const {
689     auto rankOp = cast<tensor::RankOp>(op);
690     Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
691     replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
692                                                  v);
693     return success();
694   }
695 };
696 
697 } // namespace
698 } // namespace tensor
699 } // namespace mlir
700 
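// Typical usage (illustrative; `MyPass` is a hypothetical pass): the external
// models must be registered with the DialectRegistry before bufferization
// runs, e.g. alongside a pass's dependent dialects:
//
//   void MyPass::getDependentDialects(DialectRegistry &registry) const {
//     registry.insert<memref::MemRefDialect, scf::SCFDialect>();
//     tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   }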
701 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
702     DialectRegistry &registry) {
703   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
704     CastOp::attachInterface<CastOpInterface>(*ctx);
705     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
706     DimOp::attachInterface<DimOpInterface>(*ctx);
707     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
708     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
709     ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
710     FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
711     GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
712     InsertOp::attachInterface<InsertOpInterface>(*ctx);
713     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
714     RankOp::attachInterface<RankOpInterface>(*ctx);
715   });
716 }
717