//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

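/// Try to cast the given ranked MemRef value to the given ranked MemRef type.
/// If the cast cannot be guaranteed to succeed at runtime, a new buffer of the
/// target type is allocated and the data is copied over instead. Fails if the
/// element type, rank or memory space of the two types differ.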
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // If the affine maps differ, we may need to use a copy when going from a
  // dynamic to a static offset or stride (at this point, the canonicalization
  // cannot know whether the cast is really compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` returns true, a cast is valid but may still
  // fail at runtime. To ensure that we only generate casts that are guaranteed
  // to succeed at runtime, we check a few extra conditions in
  // `isGuaranteedCastCompatible`.
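  // Illustrative example: casting from a source with a fully dynamic strided
  // layout to a destination with an identity layout would turn dynamic
  // offsets/strides into static ones; such a cast is not guaranteed to succeed
  // at runtime, so the realloc-and-copy path below is taken instead.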
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
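/// Illustrative IR (value names are hypothetical):
///
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %r = bufferization.to_memref %t : memref<4xf32>
///
/// folds so that %r is replaced with %m, inserting a memref.cast (or, in some
/// ranked cases, a reallocation and copy) when the two memref types differ.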
LogicalResult
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

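/// Populate `dynamicDims` with the dynamic dimension sizes of the given shaped
/// value, using memref.dim for memrefs and tensor.dim for ranked tensors.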
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Compute memory space of this allocation.
  unsigned memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    memorySpace =
        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
  } else if (options.defaultMemorySpace.has_value()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return op->emitError("could not infer memory space");
  }

  // Create memory allocation.
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType(),
                      AffineMap(), memorySpace);
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(options);
  bool dealloc;
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    dealloc = !escapeAttr[0].cast<BoolAttr>().getValue();
  } else {
    // No "escape" annotation found.
    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy,
        /*memory_space=*/IntegerAttr());
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size was specified with a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

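/// Fold `tensor.dim` of a `bufferization.alloc_tensor` result to the
/// corresponding dynamic size operand (or to a `tensor.dim` of the `copy`
/// value, if present), provided the queried dimension is dynamic and the
/// index is a constant.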
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
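///
/// Illustrative IR (value names are hypothetical):
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ...            // uses of %1, no other deallocations in between
///   memref.dealloc %1 : memref<?xf32>
///
/// is rewritten so that uses of %1 go through a `memref.cast` of %0 and the
/// now-redundant `memref.dealloc` is erased.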
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has more than one deallocation.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of the
    // source before the uses of the clone. With alias information, we could
    // restrict this check to only fail if the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {
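/// Replace `tensor.dim` of a `bufferization.to_tensor` result with a
/// `memref.dim` of the underlying memref.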
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
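///
/// Illustrative IR (value names are hypothetical):
///
///   %0 = tensor.cast %t : tensor<?xf32> to tensor<4xf32>
///   %1 = bufferization.to_memref %0 : memref<4xf32>
///
/// becomes
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>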
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    return foldToMemrefToTensorPair(rewriter, toMemref);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
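///
/// Illustrative IR (value names are hypothetical):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes
///
///   %v = tensor.extract %t[%i] : tensor<?xf32>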
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an
  // error, not whether this fold was actually applied.
  return success();
}

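/// Build a deallocation for `alloc` by creating a `memref.dealloc`.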
Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

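/// Build a clone of `alloc` by creating a `bufferization.clone` op.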
Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"