//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

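/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Illustrative example (a sketch; `%t_buf` stands for the bufferized source
/// and the exact result type depends on the bufferization options):
///   %1 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
/// bufferizes to:
///   %1 = memref.cast %t_buf : memref<4xf32> to memref<?xf32>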
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
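///
/// Illustrative example (a sketch; `%t_buf` stands for the bufferized source):
///   %1 = tensor.collapse_shape %t [[0, 1]]
///       : tensor<2x4xf32> into tensor<8xf32>
/// bufferizes to:
///   %1 = memref.collapse_shape %t_buf [[0, 1]]
///       : memref<2x4xf32> into memref<8xf32>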
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
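///
/// Illustrative example (a sketch; `%t_buf` stands for the bufferized source):
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// bufferizes to:
///   %d = memref.dim %t_buf, %c0 : memref<?xf32>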
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
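///
/// Illustrative example (a sketch; `%t_buf` stands for the bufferized source):
///   %1 = tensor.expand_shape %t [[0, 1]]
///       : tensor<8xf32> into tensor<2x4xf32>
/// bufferizes to:
///   %1 = memref.expand_shape %t_buf [[0, 1]]
///       : memref<8xf32> into memref<2x4xf32>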
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
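///
/// Illustrative example (a sketch; `%t_buf` stands for the bufferized source
/// and `#strided` for the strided layout inferred for the subview):
///   %0 = tensor.extract_slice %t[4][8][1] : tensor<?xf32> to tensor<8xf32>
/// bufferizes, in the in-place case, to:
///   %0 = memref.subview %t_buf[4][8][1]
///       : memref<?xf32> to memref<8xf32, #strided>
/// In the out-of-place case, a new buffer is allocated and, if the result is
/// read, the subview is copied into it.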
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
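///
/// Illustrative example (a sketch; `%t_buf` stands for the bufferized tensor):
///   %e = tensor.extract %t[%i] : tensor<?xf32>
/// bufferizes to:
///   %e = memref.load %t_buf[%i] : memref<?xf32>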
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
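//
// For illustration: with shape [2, 3], stores are created at indices
// (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), consuming `elements` in
// row-major order.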
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
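///
/// Illustrative example (a sketch; the actual allocation depends on the
/// bufferization options):
///   %0 = tensor.from_elements %a, %b, %c : tensor<3xf32>
/// bufferizes to roughly:
///   %buf = memref.alloc() : memref<3xf32>
///   memref.store %a, %buf[%c0] : memref<3xf32>
///   memref.store %b, %buf[%c1] : memref<3xf32>
///   memref.store %c, %buf[%c2] : memref<3xf32>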
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType = getContiguousMemRefType(tensorType);
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, resultType, {});
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
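///
/// Illustrative example (a sketch):
///   %0 = tensor.generate %sz {
///   ^bb0(%i : index):
///     %e = ...
///     tensor.yield %e : f32
///   } : tensor<?xf32>
/// bufferizes to an allocation of size %sz whose elements are written by an
/// scf.parallel loop: the generator body is inlined into the loop body and
/// the tensor.yield becomes a memref.store.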
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult = state.createAlloc(
        rewriter, loc, memrefType, generateOp.dynamicExtents());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
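///
/// Illustrative example (a sketch; `%t_buf` stands for the bufferized dest,
/// assuming the op bufferizes in-place):
///   %0 = tensor.insert %f into %t[%i] : tensor<?xf32>
/// bufferizes to:
///   memref.store %f, %t_buf[%i] : memref<?xf32>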
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches, i.e., the
/// operand and result are equivalent and the offsets/sizes/strides
/// specifications are the same.
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
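///
/// Example of a matching pair (a sketch):
///   %0 = tensor.extract_slice %t[%a][%c][1]
///       : tensor<?xf32> to tensor<?xf32>
///   ...
///   %1 = tensor.insert_slice %x into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>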
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st.source() != sti.dest() &&
      !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
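///
/// Illustrative example (a sketch; `%t_buf`/`%s_buf` stand for the bufferized
/// dest/source, `#strided` for the inferred subview layout, and the copy op is
/// whatever the bufferization options produce, memref.copy by default):
///   %0 = tensor.insert_slice %s into %t[4][8][1]
///       : tensor<8xf32> into tensor<?xf32>
/// bufferizes, in the in-place case, to:
///   %sv = memref.subview %t_buf[4][8][1]
///       : memref<?xf32> to memref<8xf32, #strided>
///   memref.copy %s_buf, %sv : memref<8xf32> to memref<8xf32, #strided>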
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
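///
/// Illustrative example (a sketch; `%t_buf` stands for the bufferized source):
///   %r = tensor.rank %t : tensor<*xf32>
/// bufferizes to:
///   %r = memref.rank %t_buf : memref<*xf32>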
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}