//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

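/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Example (illustrative types):
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
/// is rewritten into a cast on the source buffer:
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>
/// where %m is the buffer of %0.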
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
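///
/// Example (illustrative types):
///   %d = tensor.dim %t, %c0 : tensor<?x5xf32>
/// is rewritten into a query on the buffer of %t:
///   %d = memref.dim %m, %c0 : memref<?x5xf32>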
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
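///
/// Example (illustrative types; #strided stands for the inferred strided
/// layout): when the slice can be taken in place,
///   %0 = tensor.extract_slice %t[4, 8][4, 8][1, 1]
///       : tensor<8x16xf32> to tensor<4x8xf32>
/// is rewritten into a view on the buffer of %t:
///   %0 = memref.subview %m[4, 8][4, 8][1, 1]
///       : memref<8x16xf32> to memref<4x8xf32, #strided>
/// Otherwise, the subview is copied into a fresh allocation.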
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
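///
/// Example (illustrative types):
///   %0 = tensor.extract %t[%i, %j] : tensor<4x4xf32>
/// is rewritten into a load from the buffer of %t:
///   %0 = memref.load %m[%i, %j] : memref<4x4xf32>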
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
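// The elements are consumed in row-major order. E.g., for a 2x3 shape, stores
// are emitted at indices (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).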
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
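///
/// Example (illustrative types):
///   %0 = tensor.from_elements %a, %b : tensor<2xf32>
/// is rewritten into an allocation plus one store per element:
///   %m = memref.alloc() : memref<2xf32>
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>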
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType = getContiguousMemRefType(tensorType);
    FailureOr<Value> maybeBuffer =
        createAlloc(rewriter, loc, resultType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
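///
/// Example (illustrative types):
///   %0 = tensor.generate %n {
///   ^bb0(%i: index, %j: index):
///     ...
///     tensor.yield %v : f32
///   } : tensor<?x8xf32>
/// is rewritten into an allocation of size %n x 8 plus an scf.parallel loop
/// that stores each generated element into the buffer.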
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
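///
/// Example (illustrative types):
///   %0 = tensor.insert %v into %t[%i] : tensor<8xf32>
/// is rewritten into a store into the (possibly copied) destination buffer:
///   memref.store %v, %m[%i] : memref<8xf32>
/// and %0 is replaced with %m.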
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return op->getOpResult(0);
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches, i.e., the
/// operand / result are equivalent and both ops have the same
/// offset/sizes/strides specification.
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. It should be generalized and exposed
/// as interfaces on the proper types.
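///
/// Example: the pair
///   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
///   %1 = tensor.insert_slice %x into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>
/// matches because %t is equivalent on both ops and the offsets, sizes and
/// strides are equal.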
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
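///
/// Example: in
///   %0 = tensor.extract_slice %t ...
///   %1 = linalg.fill %cst, %0
///   %2 = tensor.insert_slice %1 into %t ...
/// the value %1 traces back (through the aliasing fill) to %0, which matches
/// the insert_slice, so the function returns true for %1.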
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
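///
/// Example (illustrative types; #strided stands for the inferred strided
/// layout, and the copy is shown as the memref.copy that the default
/// createMemCpy callback emits):
///   %0 = tensor.insert_slice %s into %t[4][8][1]
///       : tensor<8xf32> into tensor<16xf32>
/// becomes a subview of the destination buffer plus a copy:
///   %sv = memref.subview %m[4][8][1]
///       : memref<16xf32> to memref<8xf32, #strided>
///   memref.copy %src, %sv
/// where %src is the buffer of %s.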
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
      // %1 = linalg.fill %cst, %0 {inplace = [true]}
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace = [true]}
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates a new buffer.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
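///
/// Example (illustrative types):
///   %0 = tensor.rank %t : tensor<*xf32>
/// is rewritten into a query on the buffer of %t:
///   %0 = memref.rank %m : memref<*xf32>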
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}