//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

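/// Try to cast the given ranked memref value to the given ranked memref type.
/// If the cast cannot be guaranteed to succeed at runtime, allocate a new
/// buffer of the target type and copy the value into it instead. Fail if the
/// element type, rank or memory space of the two types differ.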
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
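  // For example, a cast from a memref type with a dynamic offset to one with
  // a static offset is `areCastCompatible`, but it would fail at runtime if
  // the actual offset differs; `isGuaranteedCastCompatible` rejects such
  // casts, so that an alloc + copy is emitted below instead.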
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
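///
/// Example (illustrative; value names are arbitrary):
///
///   %0 = bufferization.to_tensor %m : memref<?xf32>
///   %1 = bufferization.to_memref %0 : memref<4xf32>
///
/// folds %1 to a memref.cast of %m to memref<4xf32> (or to an alloc + copy
/// when the cast is not guaranteed to succeed at runtime).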
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.memref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.memref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.memref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.memref());
  return success();
}

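/// Populate `dynamicDims` with the dynamic dimension sizes of `shapedValue`,
/// in increasing dimension order, using memref.dim or tensor.dim ops depending
/// on the type of the value.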
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Create buffer allocation.
  Value copyBuffer;
  if (copy())
    copyBuffer = state.getBuffer(rewriter, copy());
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType());
  SmallVector<Value> dynamicDims = dynamicSizes();
  if (copy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      state.getOptions().createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (copy()) {
    if (failed(
            state.getOptions().createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(state.getOptions());
  bool dealloc;
  if (escape().hasValue()) {
    dealloc = !*escape();
  } else {
    // No "escape" annotation found.
    if (state.getOptions().createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(state.getOptions().createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(copy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (copy() && !dynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!copy() && getType().getNumDynamicDims() !=
                     static_cast<int64_t>(dynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (copy() && copy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*escape=*/BoolAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*escape=*/BoolAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy, bool escape) {
  build(builder, result, type, dynamicSizes, copy, builder.getBoolAttr(escape));
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined using an
/// `arith.constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.copy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.dynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value(),
        /*escape=*/op.escapeAttr());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << dynamicSizes() << ")";
  if (copy())
    p << " copy(" << copy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = result().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (copy())
    return b.create<tensor::DimOp>(getLoc(), copy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
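///
/// Example (illustrative; value names are arbitrary):
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   "use"(%1) : (memref<?xf32>) -> ()
///   memref.dealloc %1 : memref<?xf32>
///
/// becomes a memref.cast of %0 (and the now-redundant dealloc is erased),
/// provided no other deallocation may free the buffer in between.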
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = memref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
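    // For example (illustrative), this folds
    //   %m = bufferization.to_memref %t : memref<?xf32>
    //   %0 = bufferization.to_tensor %m : memref<?xf32>
    // to %t when the two ops are directly adjacent in the same block.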
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.tensor();
  return {};
}

namespace {

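/// Replace tensor.dim of a to_tensor op with memref.dim on the underlying
/// memref.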
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, memrefToTensorOp.memref(),
                                               dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = tensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.memref().getType() == getType())
      return memrefToTensor.memref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
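///
/// Example (illustrative; value names are arbitrary):
///
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
///   %2 = bufferization.to_memref %1 : memref<4xf32>
///
/// becomes
///
///   %1 = bufferization.to_memref %0 : memref<?xf32>
///   %2 = memref.cast %1 : memref<?xf32> to memref<4xf32>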
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
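///
/// Example (illustrative; value names are arbitrary):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes
///
///   %v = tensor.extract %t[%i] : tensor<?xf32>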
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.tensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
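///
/// Example (illustrative; value names are arbitrary):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>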
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    BufferizationState &state) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
653