//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

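/// Try to cast `value` to `destType`. If the cast is not guaranteed to succeed
/// at runtime, allocate a new buffer of type `destType` instead and copy the
/// contents of `value` into it.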
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` holds, a cast is valid, but may still fail at
  // runtime. To ensure that we only generate casts that always succeed at
  // runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
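///
/// Illustrative example (assuming the standard textual form of these ops):
///
///   %0 = bufferization.to_tensor %m : memref<4xf32>
///   %1 = bufferization.to_memref %0 : memref<?xf32>
///
/// folds to `memref.cast %m : memref<4xf32> to memref<?xf32>`.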
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

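/// Populate `dynamicDims` with the dynamic dimension sizes of `shapedValue`,
/// creating `memref.dim` or `tensor.dim` ops depending on whether the value is
/// a memref or a ranked tensor.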
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Create buffer allocation.
  Value copyBuffer;
  if (getCopy())
    copyBuffer = getBuffer(rewriter, getCopy(), options);
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType());
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(options);
  bool dealloc;
  if (getEscape()) {
    dealloc = !*getEscape();
  } else {
    // No "escape" annotation found.
    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*escape=*/BoolAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*escape=*/BoolAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy, bool escape) {
  build(builder, result, type, dynamicSizes, copy, builder.getBoolAttr(escape));
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined by a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value(),
        /*escape=*/op.getEscapeAttr());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

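/// Fold a `tensor.dim` of an `alloc_tensor` result to the corresponding
/// dynamic size (or, if the tensor is initialized from `copy`, to a
/// `tensor.dim` of the copied value).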
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

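// Custom assembly format, e.g. (illustrative):
//   %0 = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) : tensor<?x?xf32>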
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
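///
/// Illustrative example (assuming %buf is not deallocated in the same block):
///
///   %clone = bufferization.clone %buf : memref<?xf32> to memref<?xf32>
///   ...                                 // no other frees in between
///   memref.dealloc %clone : memref<?xf32>
///
/// becomes a `memref.cast` of %buf, and the now-redundant dealloc is erased.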
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has more than one deallocation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {

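/// Rewrite `tensor.dim` of a `to_tensor` result into a `memref.dim` on the
/// underlying memref.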
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
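///
/// Illustrative example (assuming the standard textual form of these ops):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = bufferization.to_memref %0 : memref<?xf32>
///
/// becomes a `to_memref` of %t followed by a `memref.cast` to memref<?xf32>.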
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when a type mismatch prevents `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
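///
/// Illustrative example (assuming the standard textual form of these ops):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes `%v = tensor.extract %t[%i] : tensor<?xf32>`.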
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
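///
/// Illustrative example (assuming the standard textual form of these ops):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes `%d = tensor.dim %t, %c0 : tensor<?xf32>`.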
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an
  // error, not whether the pattern matched.
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"