//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

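/// Bufferization of tensor.cast. Replace with memref.cast.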
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
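  // tensor.cast is a metadata-only operation: it may refine or relax the
  // shape information of the tensor value, but it never reads or writes the
  // underlying memory.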
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);
    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
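  // tensor.collapse_shape reshapes without touching any data; it bufferizes
  // to a view (memref.collapse_shape) of the source buffer.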
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
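  // tensor.dim only inspects the shape of the tensor, but the operand is
  // conservatively marked as a memory read.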
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
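  // Like tensor.collapse_shape, this bufferizes to a metadata-only view
  // (memref.expand_shape) of the source buffer.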
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
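    // The result buffer is a subview of the source buffer, which in general
    // is not equivalent to the source buffer as a whole.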
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
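    // Even if this op was decided to bufferize out-of-place, do not insert
    // the buffer copy yet. This is done later in this function.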
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
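  // Fix the index for the current (non-innermost) dimension and recurse into
  // the next dimension.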
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
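  // tensor.from_elements has no tensor operands, so only the bufferization
  // itself must be implemented.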
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType = getContiguousMemRefType(tensorType);
    FailureOr<Value> maybeBuffer =
        createAlloc(rewriter, loc, resultType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
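  // tensor.generate has no tensor operands, so only the bufferization itself
  // must be implemented.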
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

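  // Walk the reverse use-def chain of `value`: every value at which the
  // traversal stops must satisfy `condition`, i.e., be produced by a matching
  // extract_slice.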
  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace = [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace = [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place
    // is generally a deal breaker. When used with loops, this ends up cloning
    // the whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `state.getBuffer` allocates a new buffer.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}