//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

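/// Try to cast the given ranked memref value to the given ranked memref type,
/// or fall back to allocating a new buffer of the target type and copying the
/// data when a cast is not guaranteed to succeed at runtime.
///
/// Rough sketch of the fallback path for a 1-D memref (illustrative IR only;
/// value names are made up):
///
///   %size = memref.dim %value, %c0 : memref<?xf32>
///   %copy = memref.alloc(%size) : memref<?xf32>
///   memref.copy %value, %copy : memref<?xf32> to memref<?xf32>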
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible` returns true, a cast is valid but may still
  // fail at runtime. To ensure that we only generate casts that always succeed
  // at runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
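///
/// Hand-written illustration (types chosen only for exposition):
///
///   %t = bufferization.to_tensor %m : memref<4xf32>
///   %r = bufferization.to_memref %t : memref<?xf32>
///
/// can be rewritten to
///
///   %r = memref.cast %m : memref<4xf32> to memref<?xf32>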
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

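/// Populate `dynamicDims` with the sizes of the dynamic dimensions of
/// `shapedValue`, in dimension order. A sketch of the emitted IR for a memref
/// of type memref<?x3x?xf32> (illustrative only):
///
///   %d0 = memref.dim %shapedValue, %c0 : memref<?x3x?xf32>
///   %d2 = memref.dim %shapedValue, %c2 : memref<?x3x?xf32>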
void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

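/// Rough sketch of what bufferization produces for an `alloc_tensor`
/// (illustrative only; the exact ops depend on `BufferizationOptions`, the
/// optional `copy` operand and the escape analysis below):
///
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
///
/// typically becomes
///
///   %buf = memref.alloc(%sz) : memref<?xf32>
///   ...
///   memref.dealloc %buf : memref<?xf32>  // emitted only if the buffer does
///                                        // not escape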
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Compute memory space of this allocation.
  unsigned memorySpace;
  if (getMemorySpace().hasValue()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    memorySpace =
        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
  } else if (options.defaultMemorySpace.hasValue()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return op->emitError("could not infer memory space");
  }

  // Create memory allocation.
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType(),
                      AffineMap(), memorySpace);
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(options);
  bool dealloc;
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    dealloc = !escapeAttr[0].cast<BoolAttr>().getValue();
  } else {
    // No "escape" annotation found.
    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy,
        /*memory_space=*/IntegerAttr());
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size was provided by a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

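/// Fold `tensor.dim` of a dynamic dimension of an `alloc_tensor` to the
/// corresponding dynamic size value. Hand-written illustration:
///
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
///
/// folds %d to %sz.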
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

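/// The textual form accepted by `parse` and produced by `print` below is,
/// roughly (illustrative; `copy(...)` replaces the dynamic sizes, which must
/// then be empty):
///
///   bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
///   bufferization.alloc_tensor() copy(%t) : tensor<?x?xf32>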
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
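///
/// Hand-written illustration (assuming the dealloc of the clone result is the
/// only dealloc and no other frees occur in between):
///
///   %clone = bufferization.clone %source : memref<?xf32> to memref<?xf32>
///   ...
///   memref.dealloc %clone : memref<?xf32>
///
/// becomes
///
///   %clone = memref.cast %source : memref<?xf32> to memref<?xf32>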
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has more than one deallocation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

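/// Fold to_tensor(to_memref(x)) back to x; hand-written illustration:
///
///   %m = bufferization.to_memref %x : memref<4xf32>
///   %t = bufferization.to_tensor %m : memref<4xf32>
///
/// folds %t to %x when the two ops are adjacent in the same block (see the
/// conservative check below).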
OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {

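/// Replace `tensor.dim` of a `to_tensor` result with `memref.dim` of the
/// underlying memref. Hand-written illustration:
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %d = tensor.dim %t, %idx : tensor<?xf32>
///
/// becomes
///
///   %d = memref.dim %m, %idx : memref<?xf32>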
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
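///
/// Hand-written illustration (types for exposition only):
///
///   %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %m = bufferization.to_memref %c : memref<?xf32>
///
/// becomes
///
///   %m0 = bufferization.to_memref %t : memref<4xf32>
///   %m  = memref.cast %m0 : memref<4xf32> to memref<?xf32>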
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
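///
/// Hand-written illustration:
///
///   %m = bufferization.to_memref %t : memref<4xf32>
///   %v = memref.load %m[%i] : memref<4xf32>
///
/// becomes
///
///   %v = tensor.extract %t[%i] : tensor<4xf32>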
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
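///
/// Hand-written illustration:
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %idx : memref<?xf32>
///
/// becomes
///
///   %d = tensor.dim %t, %idx : tensor<?xf32>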
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"