//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

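/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Illustrative rewrite (schematic IR, not necessarily verbatim output):
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
/// roughly becomes
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>
/// where %m is the buffer computed for %0.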
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType;
    if (resultTensorType.isa<RankedTensorType>()) {
      resultMemRefType =
          getContiguousMemRefType(resultTensorType, layout, memorySpace);
    } else {
      resultMemRefType =
          getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
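///
/// Illustrative rewrite (schematic IR):
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// roughly becomes
///   %d = memref.dim %m, %c0 : memref<?xf32>
/// where %m is the buffer computed for %t.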
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
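///
/// Illustrative rewrite (schematic IR; #strided stands for the strided layout
/// map inferred from the source memref):
///   %0 = tensor.extract_slice %t[4][16][1] : tensor<?xf32> to tensor<16xf32>
/// roughly becomes, when bufferized in-place,
///   %0 = memref.subview %m[4][16][1]
///       : memref<?xf32> to memref<16xf32, #strided>
/// When bufferized out-of-place, a new buffer is allocated and the subview is
/// copied into it (unless the result is never read).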
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
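///
/// Illustrative rewrite (schematic IR):
///   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
/// roughly becomes
///   %e = memref.load %m[%i, %j] : memref<?x?xf32>
/// where %m is the buffer computed for %t.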
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
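// E.g., for a tensor<2x3xf32>, the elements are consumed in row-major order
// and stored at indices (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).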
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
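///
/// Illustrative rewrite (schematic IR):
///   %0 = tensor.from_elements %a, %b : tensor<2xf32>
/// roughly becomes
///   %m = memref.alloc() : memref<2xf32>
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>
/// (plus a matching dealloc if the bufferization options request one).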
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    FailureOr<Value> maybeBuffer =
        createAlloc(rewriter, loc, resultType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
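///
/// Illustrative rewrite (schematic IR): the op's body is inlined into an
/// scf.parallel loop that stores each generated element into a newly
/// allocated buffer, roughly
///   %m = memref.alloc(%sz) : memref<?xf32>
///   scf.parallel (%i) = (%c0) to (%sz) step (%c1) {
///     %elem = ... // inlined tensor.generate body
///     memref.store %elem, %m[%i] : memref<?xf32>
///   }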
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
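///
/// Illustrative rewrite (schematic IR):
///   %0 = tensor.insert %v into %t[%i] : tensor<?xf32>
/// roughly becomes a store into the buffer computed for %t (which may be a
/// copy if the insertion cannot happen in-place):
///   memref.store %v, %m[%i] : memref<?xf32>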
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return op->getOpResult(0);
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and the same offset/sizes/strides
/// specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
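///
/// Illustrative rewrite (schematic IR; #strided stands for the inferred
/// strided layout map):
///   %0 = tensor.insert_slice %s into %t[4][16][1]
///       : tensor<16xf32> into tensor<?xf32>
/// roughly becomes
///   %sv = memref.subview %dst[4][16][1]
///        : memref<?xf32> to memref<16xf32, #strided>
///   <copy>(%src, %sv)
/// where %src and %dst are the buffers computed for %s and %t, and <copy> is
/// the copy operation chosen by the bufferization options.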
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    // When bufferizing out-of-place, `getBuffer` allocates a new buffer.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
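///
/// Illustrative rewrite (schematic IR):
///   %r = tensor.rank %t : tensor<*xf32>
/// roughly becomes
///   %r = memref.rank %m : memref<*xf32>
/// where %m is the buffer computed for %t.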
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}