//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

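  // Illustrative sketch (not from the original source): a cast that refines a
  // dynamic dimension, e.g.
  //
  //   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
  //
  // bufferizes to a memref.cast of the source buffer, roughly
  //
  //   %m1 = memref.cast %m0 : memref<?xf32> to memref<4xf32>
  //
  // The exact result layout and memory space are derived from the source
  // buffer and the bufferization options below.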
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

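  // Illustrative sketch (not from the original source; #strided1D stands for
  // whatever strided layout the subview infers): in the in-place case,
  //
  //   %1 = tensor.extract_slice %0[%o][%s][1] : tensor<?xf32> to tensor<?xf32>
  //
  // becomes a subview of the source buffer, roughly
  //
  //   %v = memref.subview %m[%o][%s][1] : memref<?xf32> to memref<?xf32, #strided1D>
  //
  // In the out-of-place case, a new buffer is allocated and the subview is
  // copied into it (see below).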
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

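  // Illustrative sketch (not from the original source):
  //
  //   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
  //
  // bufferizes to a load from the source buffer:
  //
  //   %e = memref.load %m[%i, %j] : memref<?x?xf32>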
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
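//
// Illustrative sketch (not from the original source): for a
// tensor.from_elements that produces a tensor<2x2xf32> from %a, %b, %c, %d,
// the recursion emits the stores in row-major order (%c0/%c1 are index
// constants):
//
//   memref.store %a, %buf[%c0, %c0]
//   memref.store %b, %buf[%c0, %c1]
//   memref.store %c, %buf[%c1, %c0]
//   memref.store %d, %buf[%c1, %c1]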
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType = getContiguousMemRefType(tensorType);
    FailureOr<Value> maybeBuffer =
        createAlloc(rewriter, loc, resultType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
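    //
    // Illustrative sketch (not from the original source):
    //
    //   %t = tensor.generate %d {
    //   ^bb0(%i : index, %j : index):
    //     ...
    //     tensor.yield %elem : f32
    //   } : tensor<?x4xf32>
    //
    // becomes roughly
    //
    //   scf.parallel (%i, %j) = (%c0, %c0) to (%d, %c4) step (%c1, %c1) {
    //     ...
    //     memref.store %elem, %alloc[%i, %j]
    //     scf.yield
    //   }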
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

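  // Illustrative sketch (not from the original source):
  //
  //   %1 = tensor.insert %f into %0[%i] : tensor<?xf32>
  //
  // bufferizes to a store into the destination buffer:
  //
  //   memref.store %f, %m[%i] : memref<?xf32>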
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
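///
/// Illustrative sketch (not from the original source): the pair
///
///   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   %2 = tensor.insert_slice %1 into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>
///
/// matches, because both slices use the same offset/size/stride and operate
/// on equivalent buffers (%t).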
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

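  // Illustrative sketch (not from the original source):
  //
  //   %r = tensor.insert_slice %src into %dst[%o][%s][1]
  //       : tensor<?xf32> into tensor<?xf32>
  //
  // bufferizes to a subview of the destination buffer plus a copy, roughly
  //
  //   %sv = memref.subview %dst_m[%o][%s][1]
  //   <copy of %src_m into %sv, using whatever copy op the options configure>
  //
  // and the op is then replaced with the destination buffer.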
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

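// Illustrative sketch (not from the original source): clients typically call
// the registration function below while assembling their DialectRegistry,
// before running bufferization, e.g.
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);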
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}