157470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
257470abcSAlexander Belyaev //
357470abcSAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
457470abcSAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
557470abcSAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
657470abcSAlexander Belyaev //
757470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
857470abcSAlexander Belyaev 
9eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
129a3d60e0SAart Bik #include "mlir/Dialect/Func/IR/FuncOps.h"
13eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
1457470abcSAlexander Belyaev #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
159a3d60e0SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
16eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h"
17ec55f0bdSMatthias Springer #include "mlir/IR/Matchers.h"
1857470abcSAlexander Belyaev 
1957470abcSAlexander Belyaev using namespace mlir;
2057470abcSAlexander Belyaev using namespace mlir::bufferization;
2157470abcSAlexander Belyaev 
2257470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
23fa7c8cb4SMatthias Springer // Helper functions
24fa7c8cb4SMatthias Springer //===----------------------------------------------------------------------===//
25fa7c8cb4SMatthias Springer 
26fa7c8cb4SMatthias Springer FailureOr<Value>
castOrReallocMemRefValue(OpBuilder & b,Value value,MemRefType destType)27fa7c8cb4SMatthias Springer mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
28fa7c8cb4SMatthias Springer                                               MemRefType destType) {
29fa7c8cb4SMatthias Springer   auto srcType = value.getType().cast<MemRefType>();
30fa7c8cb4SMatthias Springer 
31fa7c8cb4SMatthias Springer   // Element type, rank and memory space must match.
32fa7c8cb4SMatthias Springer   if (srcType.getElementType() != destType.getElementType())
33fa7c8cb4SMatthias Springer     return failure();
34fa7c8cb4SMatthias Springer   if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
35fa7c8cb4SMatthias Springer     return failure();
36fa7c8cb4SMatthias Springer   if (srcType.getRank() != destType.getRank())
37fa7c8cb4SMatthias Springer     return failure();
38fa7c8cb4SMatthias Springer 
39fa7c8cb4SMatthias Springer   // In case the affine maps are different, we may need to use a copy if we go
40fa7c8cb4SMatthias Springer   // from dynamic to static offset or stride (the canonicalization cannot know
41fa7c8cb4SMatthias Springer   // at this point that it is really cast compatible).
42fa7c8cb4SMatthias Springer   auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
43fa7c8cb4SMatthias Springer     int64_t sourceOffset, targetOffset;
44fa7c8cb4SMatthias Springer     SmallVector<int64_t, 4> sourceStrides, targetStrides;
45fa7c8cb4SMatthias Springer     if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
46fa7c8cb4SMatthias Springer         failed(getStridesAndOffset(target, targetStrides, targetOffset)))
47fa7c8cb4SMatthias Springer       return false;
48fa7c8cb4SMatthias Springer     auto dynamicToStatic = [](int64_t a, int64_t b) {
49fa7c8cb4SMatthias Springer       return a == MemRefType::getDynamicStrideOrOffset() &&
50fa7c8cb4SMatthias Springer              b != MemRefType::getDynamicStrideOrOffset();
51fa7c8cb4SMatthias Springer     };
52fa7c8cb4SMatthias Springer     if (dynamicToStatic(sourceOffset, targetOffset))
53fa7c8cb4SMatthias Springer       return false;
54fa7c8cb4SMatthias Springer     for (auto it : zip(sourceStrides, targetStrides))
55fa7c8cb4SMatthias Springer       if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
56fa7c8cb4SMatthias Springer         return false;
57fa7c8cb4SMatthias Springer     return true;
58fa7c8cb4SMatthias Springer   };
59fa7c8cb4SMatthias Springer 
60fa7c8cb4SMatthias Springer   // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
61fa7c8cb4SMatthias Springer   // ensure that we only generate casts that always succeed at runtime, we check
62fa7c8cb4SMatthias Springer   // a fix extra conditions in `isGuaranteedCastCompatible`.
63fa7c8cb4SMatthias Springer   if (memref::CastOp::areCastCompatible(srcType, destType) &&
64fa7c8cb4SMatthias Springer       isGuaranteedCastCompatible(srcType, destType)) {
65fa7c8cb4SMatthias Springer     Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
66fa7c8cb4SMatthias Springer     return casted;
67fa7c8cb4SMatthias Springer   }
68fa7c8cb4SMatthias Springer 
69fa7c8cb4SMatthias Springer   auto loc = value.getLoc();
70fa7c8cb4SMatthias Springer   SmallVector<Value, 4> dynamicOperands;
71fa7c8cb4SMatthias Springer   for (int i = 0; i < destType.getRank(); ++i) {
72fa7c8cb4SMatthias Springer     if (destType.getShape()[i] != ShapedType::kDynamicSize)
73fa7c8cb4SMatthias Springer       continue;
74fa7c8cb4SMatthias Springer     auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
75fa7c8cb4SMatthias Springer     Value size = b.create<memref::DimOp>(loc, value, index);
76fa7c8cb4SMatthias Springer     dynamicOperands.push_back(size);
77fa7c8cb4SMatthias Springer   }
78fa7c8cb4SMatthias Springer   // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
79fa7c8cb4SMatthias Springer   // BufferizableOpInterface impl of ToMemrefOp.
80fa7c8cb4SMatthias Springer   Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
81fa7c8cb4SMatthias Springer   b.create<memref::CopyOp>(loc, value, copy);
82fa7c8cb4SMatthias Springer   return copy;
83fa7c8cb4SMatthias Springer }
84fa7c8cb4SMatthias Springer 
85d820acddSMatthias Springer /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
86d820acddSMatthias Springer /// to_memref op are different, a memref.cast is needed.
87cb471241SMatthias Springer LogicalResult
foldToMemrefToTensorPair(RewriterBase & rewriter,ToMemrefOp toMemref)88cb471241SMatthias Springer mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
89cb471241SMatthias Springer                                               ToMemrefOp toMemref) {
9099260e95SMatthias Springer   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
91d820acddSMatthias Springer   if (!memrefToTensor)
92d820acddSMatthias Springer     return failure();
93d820acddSMatthias Springer 
9499260e95SMatthias Springer   Type srcType = memrefToTensor.getMemref().getType();
95d820acddSMatthias Springer   Type destType = toMemref.getType();
96d820acddSMatthias Springer 
97d820acddSMatthias Springer   // Directly rewrite if the type did not change.
98d820acddSMatthias Springer   if (srcType == destType) {
9999260e95SMatthias Springer     rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
100d820acddSMatthias Springer     return success();
101d820acddSMatthias Springer   }
102d820acddSMatthias Springer 
103d820acddSMatthias Springer   auto rankedSrcType = srcType.dyn_cast<MemRefType>();
104d820acddSMatthias Springer   auto rankedDestType = destType.dyn_cast<MemRefType>();
105d820acddSMatthias Springer   auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
106d820acddSMatthias Springer 
107d820acddSMatthias Springer   // Ranked memref -> Ranked memref cast.
108d820acddSMatthias Springer   if (rankedSrcType && rankedDestType) {
109d820acddSMatthias Springer     FailureOr<Value> replacement = castOrReallocMemRefValue(
11099260e95SMatthias Springer         rewriter, memrefToTensor.getMemref(), rankedDestType);
111d820acddSMatthias Springer     if (failed(replacement))
112d820acddSMatthias Springer       return failure();
113d820acddSMatthias Springer 
114d820acddSMatthias Springer     rewriter.replaceOp(toMemref, *replacement);
115d820acddSMatthias Springer     return success();
116d820acddSMatthias Springer   }
117d820acddSMatthias Springer 
118d820acddSMatthias Springer   // Unranked memref -> Ranked memref cast: May require a copy.
119d820acddSMatthias Springer   // TODO: Not implemented at the moment.
120d820acddSMatthias Springer   if (unrankedSrcType && rankedDestType)
121d820acddSMatthias Springer     return failure();
122d820acddSMatthias Springer 
123d820acddSMatthias Springer   // Unranked memref -> unranked memref cast
124d820acddSMatthias Springer   // Ranked memref -> unranked memref cast: No copy needed.
125d820acddSMatthias Springer   assert(memref::CastOp::areCastCompatible(srcType, destType) &&
126d820acddSMatthias Springer          "expected that types are cast compatible");
127d820acddSMatthias Springer   rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
12899260e95SMatthias Springer                                               memrefToTensor.getMemref());
129d820acddSMatthias Springer   return success();
130d820acddSMatthias Springer }
131d820acddSMatthias Springer 
populateDynamicDimSizes(OpBuilder & b,Location loc,Value shapedValue,SmallVector<Value> & dynamicDims)132b3ebe3beSMatthias Springer void mlir::bufferization::populateDynamicDimSizes(
133b3ebe3beSMatthias Springer     OpBuilder &b, Location loc, Value shapedValue,
134b3ebe3beSMatthias Springer     SmallVector<Value> &dynamicDims) {
135b3ebe3beSMatthias Springer   auto shapedType = shapedValue.getType().cast<ShapedType>();
136b3ebe3beSMatthias Springer   for (int64_t i = 0; i < shapedType.getRank(); ++i) {
137b3ebe3beSMatthias Springer     if (shapedType.isDynamicDim(i)) {
138b3ebe3beSMatthias Springer       if (shapedType.isa<MemRefType>()) {
139b3ebe3beSMatthias Springer         dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
140b3ebe3beSMatthias Springer       } else {
141b3ebe3beSMatthias Springer         assert(shapedType.isa<RankedTensorType>() && "expected tensor");
142b3ebe3beSMatthias Springer         dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
143b3ebe3beSMatthias Springer       }
144b3ebe3beSMatthias Springer     }
145b3ebe3beSMatthias Springer   }
146b3ebe3beSMatthias Springer }
147b3ebe3beSMatthias Springer 
148fa7c8cb4SMatthias Springer //===----------------------------------------------------------------------===//
149ffdbecccSMatthias Springer // AllocTensorOp
150ffdbecccSMatthias Springer //===----------------------------------------------------------------------===//
151ffdbecccSMatthias Springer 
bufferize(RewriterBase & rewriter,const BufferizationOptions & options)152ffdbecccSMatthias Springer LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
153b55d55ecSMatthias Springer                                        const BufferizationOptions &options) {
154b3ebe3beSMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
1553474d10eSMatthias Springer   Operation *op = this->getOperation();
156b3ebe3beSMatthias Springer   Location loc = getLoc();
157ffdbecccSMatthias Springer 
158b3ebe3beSMatthias Springer   // Nothing to do for dead AllocTensorOps.
159b3ebe3beSMatthias Springer   if (getOperation()->getUses().empty()) {
160b3ebe3beSMatthias Springer     rewriter.eraseOp(getOperation());
161b3ebe3beSMatthias Springer     return success();
162b3ebe3beSMatthias Springer   }
163b3ebe3beSMatthias Springer 
164c06f01ffSMatthias Springer   // Get "copy" buffer.
165b3ebe3beSMatthias Springer   Value copyBuffer;
1665d50f51cSMatthias Springer   if (getCopy()) {
1675d50f51cSMatthias Springer     FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
1685d50f51cSMatthias Springer     if (failed(maybeCopyBuffer))
1695d50f51cSMatthias Springer       return failure();
1705d50f51cSMatthias Springer     copyBuffer = *maybeCopyBuffer;
1715d50f51cSMatthias Springer   }
172c06f01ffSMatthias Springer 
173c06f01ffSMatthias Springer   // Compute memory space of this allocation.
174c06f01ffSMatthias Springer   unsigned memorySpace;
175491d2701SKazu Hirata   if (getMemorySpace().has_value()) {
176c06f01ffSMatthias Springer     memorySpace = *getMemorySpace();
177c0b0b6a0SMatthias Springer   } else if (getCopy()) {
178c0b0b6a0SMatthias Springer     memorySpace =
179c0b0b6a0SMatthias Springer         copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
180491d2701SKazu Hirata   } else if (options.defaultMemorySpace.has_value()) {
181c06f01ffSMatthias Springer     memorySpace = *options.defaultMemorySpace;
182c06f01ffSMatthias Springer   } else {
183c06f01ffSMatthias Springer     return op->emitError("could not infer memory space");
184c06f01ffSMatthias Springer   }
185c06f01ffSMatthias Springer 
186c06f01ffSMatthias Springer   // Create memory allocation.
187b3ebe3beSMatthias Springer   auto allocType =
188c06f01ffSMatthias Springer       MemRefType::get(getType().getShape(), getType().getElementType(),
189c06f01ffSMatthias Springer                       AffineMap(), memorySpace);
19099260e95SMatthias Springer   SmallVector<Value> dynamicDims = getDynamicSizes();
19199260e95SMatthias Springer   if (getCopy()) {
192b3ebe3beSMatthias Springer     assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
193b3ebe3beSMatthias Springer     populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
194b3ebe3beSMatthias Springer   }
19556d68e8dSMatthias Springer   FailureOr<Value> alloc =
196b55d55ecSMatthias Springer       options.createAlloc(rewriter, loc, allocType, dynamicDims);
197ffdbecccSMatthias Springer   if (failed(alloc))
198ffdbecccSMatthias Springer     return failure();
199b3ebe3beSMatthias Springer 
200b3ebe3beSMatthias Springer   // Create memory copy (if any).
20199260e95SMatthias Springer   if (getCopy()) {
202b55d55ecSMatthias Springer     if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
20356d68e8dSMatthias Springer       return failure();
20456d68e8dSMatthias Springer   }
205b3ebe3beSMatthias Springer 
206b3ebe3beSMatthias Springer   // Should the buffer be deallocated?
207*664ffa46SMatthias Springer   bool dealloc =
208*664ffa46SMatthias Springer       shouldDeallocateOpResult(getResult().cast<OpResult>(), options);
209b3ebe3beSMatthias Springer 
210b3ebe3beSMatthias Springer   // Replace op.
211ffdbecccSMatthias Springer   replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
212b3ebe3beSMatthias Springer 
213b3ebe3beSMatthias Springer   // Create buffer deallocation (if requested).
214b3ebe3beSMatthias Springer   if (!dealloc)
215b3ebe3beSMatthias Springer     return success();
216b3ebe3beSMatthias Springer 
217b3ebe3beSMatthias Springer   rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
218b55d55ecSMatthias Springer   if (failed(options.createDealloc(rewriter, loc, *alloc)))
219b3ebe3beSMatthias Springer     return failure();
220ffdbecccSMatthias Springer   return success();
221ffdbecccSMatthias Springer }
222ffdbecccSMatthias Springer 
isMemoryWrite(OpResult opResult,const AnalysisState & state)22356d68e8dSMatthias Springer bool AllocTensorOp::isMemoryWrite(OpResult opResult,
22456d68e8dSMatthias Springer                                   const AnalysisState &state) {
22556d68e8dSMatthias Springer   // AllocTensorOps do not write unless they have a `copy` value.
22699260e95SMatthias Springer   return static_cast<bool>(getCopy());
22756d68e8dSMatthias Springer }
22856d68e8dSMatthias Springer 
bufferizesToMemoryRead(OpOperand & opOperand,const AnalysisState & state)22956d68e8dSMatthias Springer bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
23056d68e8dSMatthias Springer                                            const AnalysisState &state) {
23156d68e8dSMatthias Springer   assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
23256d68e8dSMatthias Springer          "expected copy operand");
23356d68e8dSMatthias Springer   return true;
23456d68e8dSMatthias Springer }
23556d68e8dSMatthias Springer 
bufferizesToMemoryWrite(OpOperand & opOperand,const AnalysisState & state)23656d68e8dSMatthias Springer bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
23756d68e8dSMatthias Springer                                             const AnalysisState &state) {
23856d68e8dSMatthias Springer   assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
23956d68e8dSMatthias Springer          "expected copy operand");
24056d68e8dSMatthias Springer   return false;
24156d68e8dSMatthias Springer }
24256d68e8dSMatthias Springer 
24356d68e8dSMatthias Springer SmallVector<OpResult>
getAliasingOpResult(OpOperand & opOperand,const AnalysisState & state)24456d68e8dSMatthias Springer AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
24556d68e8dSMatthias Springer                                    const AnalysisState &state) {
24656d68e8dSMatthias Springer   // This is a new allocation. It does not alias with any other buffer.
24756d68e8dSMatthias Springer   return {};
24856d68e8dSMatthias Springer }
24956d68e8dSMatthias Springer 
verify()250ffdbecccSMatthias Springer LogicalResult AllocTensorOp::verify() {
25199260e95SMatthias Springer   if (getCopy() && !getDynamicSizes().empty())
25256d68e8dSMatthias Springer     return emitError("dynamic sizes not needed when copying a tensor");
25399260e95SMatthias Springer   if (!getCopy() && getType().getNumDynamicDims() !=
25499260e95SMatthias Springer                         static_cast<int64_t>(getDynamicSizes().size()))
255ec55f0bdSMatthias Springer     return emitError("expected ")
256ec55f0bdSMatthias Springer            << getType().getNumDynamicDims() << " dynamic sizes";
25799260e95SMatthias Springer   if (getCopy() && getCopy().getType() != getType())
25856d68e8dSMatthias Springer     return emitError("expected that `copy` and return type match");
2599a3d60e0SAart Bik 
2609a3d60e0SAart Bik   // For sparse tensor allocation, we require that none of its
2619a3d60e0SAart Bik   // uses escapes the function boundary directly.
2629a3d60e0SAart Bik   if (sparse_tensor::getSparseTensorEncoding(getType())) {
2639a3d60e0SAart Bik     for (auto &use : getOperation()->getUses())
2649a3d60e0SAart Bik       if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
2659a3d60e0SAart Bik               use.getOwner()))
2669a3d60e0SAart Bik         return emitError("sparse tensor allocation should not escape function");
2679a3d60e0SAart Bik   }
2689a3d60e0SAart Bik 
269ffdbecccSMatthias Springer   return success();
270ffdbecccSMatthias Springer }
271ffdbecccSMatthias Springer 
build(OpBuilder & builder,OperationState & result,RankedTensorType type,ValueRange dynamicSizes)27256d68e8dSMatthias Springer void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
27356d68e8dSMatthias Springer                           RankedTensorType type, ValueRange dynamicSizes) {
274c06f01ffSMatthias Springer   build(builder, result, type, dynamicSizes, /*copy=*/Value(),
2750d0a94a7SMatthias Springer         /*memory_space=*/IntegerAttr());
276c06f01ffSMatthias Springer }
277c06f01ffSMatthias Springer 
build(OpBuilder & builder,OperationState & result,RankedTensorType type,ValueRange dynamicSizes,Value copy)278c06f01ffSMatthias Springer void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
279c06f01ffSMatthias Springer                           RankedTensorType type, ValueRange dynamicSizes,
280c06f01ffSMatthias Springer                           Value copy) {
2810d0a94a7SMatthias Springer   build(builder, result, type, dynamicSizes, copy,
2820d0a94a7SMatthias Springer         /*memory_space=*/IntegerAttr());
28356d68e8dSMatthias Springer }
28456d68e8dSMatthias Springer 
285ffdbecccSMatthias Springer namespace {
286ffdbecccSMatthias Springer /// Change the type of the result of a `bufferization.alloc_tensor` by making
287ffdbecccSMatthias Springer /// the result type statically sized along dimension that in the original
288ffdbecccSMatthias Springer /// operation where defined as dynamic, but the size was defined using a
289ffdbecccSMatthias Springer /// `constant` op. For example:
290ffdbecccSMatthias Springer ///
291ffdbecccSMatthias Springer ///  %c5 = arith.constant 5: index
292ec55f0bdSMatthias Springer ///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
293ffdbecccSMatthias Springer ///
294ffdbecccSMatthias Springer ///  to
295ffdbecccSMatthias Springer ///
296ec55f0bdSMatthias Springer ///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
297ffdbecccSMatthias Springer struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
298ffdbecccSMatthias Springer   using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
299ffdbecccSMatthias Springer 
matchAndRewrite__anon56ded4390311::ReplaceStaticShapeDims300ffdbecccSMatthias Springer   LogicalResult matchAndRewrite(AllocTensorOp op,
301ffdbecccSMatthias Springer                                 PatternRewriter &rewriter) const override {
30299260e95SMatthias Springer     if (op.getCopy())
30356d68e8dSMatthias Springer       return failure();
304ec55f0bdSMatthias Springer     SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
305ec55f0bdSMatthias Springer     SmallVector<Value> newDynamicSizes;
306ec55f0bdSMatthias Springer     unsigned int dynValCounter = 0;
307ec55f0bdSMatthias Springer     for (int64_t i = 0; i < op.getType().getRank(); ++i) {
308ec55f0bdSMatthias Springer       if (!op.isDynamicDim(i))
309ffdbecccSMatthias Springer         continue;
31099260e95SMatthias Springer       Value value = op.getDynamicSizes()[dynValCounter++];
311ec55f0bdSMatthias Springer       APInt intVal;
312ec55f0bdSMatthias Springer       if (matchPattern(value, m_ConstantInt(&intVal))) {
313ec55f0bdSMatthias Springer         newShape[i] = intVal.getSExtValue();
314ec55f0bdSMatthias Springer       } else {
315ec55f0bdSMatthias Springer         newDynamicSizes.push_back(value);
316ffdbecccSMatthias Springer       }
317ffdbecccSMatthias Springer     }
318ec55f0bdSMatthias Springer     RankedTensorType newType = RankedTensorType::get(
319ec55f0bdSMatthias Springer         newShape, op.getType().getElementType(), op.getType().getEncoding());
320ffdbecccSMatthias Springer     if (newType == op.getType())
321ffdbecccSMatthias Springer       return failure();
32256d68e8dSMatthias Springer     auto newOp = rewriter.create<AllocTensorOp>(
3233474d10eSMatthias Springer         op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
324ffdbecccSMatthias Springer     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
325ffdbecccSMatthias Springer     return success();
326ffdbecccSMatthias Springer   }
327ffdbecccSMatthias Springer };
328ffdbecccSMatthias Springer 
329ffdbecccSMatthias Springer struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
330ffdbecccSMatthias Springer   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
331ffdbecccSMatthias Springer 
matchAndRewrite__anon56ded4390311::FoldDimOfAllocTensorOp332ffdbecccSMatthias Springer   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
333ffdbecccSMatthias Springer                                 PatternRewriter &rewriter) const override {
334ffdbecccSMatthias Springer     Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
33504235d07SJacques Pienaar     auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
336ffdbecccSMatthias Springer     if (!allocTensorOp || !maybeConstantIndex)
337ffdbecccSMatthias Springer       return failure();
338ec55f0bdSMatthias Springer     if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
339ffdbecccSMatthias Springer       return failure();
34056d68e8dSMatthias Springer     rewriter.replaceOp(
34156d68e8dSMatthias Springer         dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
342ffdbecccSMatthias Springer     return success();
343ffdbecccSMatthias Springer   }
344ffdbecccSMatthias Springer };
345ffdbecccSMatthias Springer } // namespace
346ffdbecccSMatthias Springer 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * ctx)347ffdbecccSMatthias Springer void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
348ffdbecccSMatthias Springer                                                 MLIRContext *ctx) {
349ffdbecccSMatthias Springer   results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
350ffdbecccSMatthias Springer }
351ffdbecccSMatthias Springer 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)352ffdbecccSMatthias Springer LogicalResult AllocTensorOp::reifyResultShapes(
353ffdbecccSMatthias Springer     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
354ffdbecccSMatthias Springer   auto shapes = llvm::to_vector<4>(llvm::map_range(
355ffdbecccSMatthias Springer       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
356ec55f0bdSMatthias Springer         if (isDynamicDim(dim))
35756d68e8dSMatthias Springer           return getDynamicSize(builder, dim);
358ffdbecccSMatthias Springer         return builder.create<arith::ConstantIndexOp>(getLoc(),
359ffdbecccSMatthias Springer                                                       getStaticSize(dim));
360ffdbecccSMatthias Springer       }));
361ffdbecccSMatthias Springer   reifiedReturnShapes.emplace_back(std::move(shapes));
362ffdbecccSMatthias Springer   return success();
363ffdbecccSMatthias Springer }
364ffdbecccSMatthias Springer 
parse(OpAsmParser & parser,OperationState & result)36556d68e8dSMatthias Springer ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
36656d68e8dSMatthias Springer   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
36756d68e8dSMatthias Springer   if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
36856d68e8dSMatthias Springer       parser.parseRParen())
36956d68e8dSMatthias Springer     return failure();
37056d68e8dSMatthias Springer   ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
37156d68e8dSMatthias Springer   OpAsmParser::UnresolvedOperand copyOperand;
37256d68e8dSMatthias Springer   if (copyKeyword.succeeded())
37356d68e8dSMatthias Springer     if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
37456d68e8dSMatthias Springer         parser.parseRParen())
37556d68e8dSMatthias Springer       return failure();
37656d68e8dSMatthias Springer   if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
37756d68e8dSMatthias Springer     return failure();
37856d68e8dSMatthias Springer 
37956d68e8dSMatthias Springer   TensorType type;
38056d68e8dSMatthias Springer   if (parser.parseCustomTypeWithFallback(type))
38156d68e8dSMatthias Springer     return failure();
38256d68e8dSMatthias Springer   result.addTypes(type);
38356d68e8dSMatthias Springer 
38456d68e8dSMatthias Springer   Type indexType = parser.getBuilder().getIndexType();
38556d68e8dSMatthias Springer   if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
38656d68e8dSMatthias Springer     return failure();
38756d68e8dSMatthias Springer   if (copyKeyword.succeeded())
38856d68e8dSMatthias Springer     if (parser.resolveOperand(copyOperand, type, result.operands))
38956d68e8dSMatthias Springer       return failure();
39056d68e8dSMatthias Springer   result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
39156d68e8dSMatthias Springer                       parser.getBuilder().getI32VectorAttr(
39256d68e8dSMatthias Springer                           {static_cast<int32_t>(dynamicSizesOperands.size()),
39356d68e8dSMatthias Springer                            static_cast<int32_t>(copyKeyword.succeeded())}));
39456d68e8dSMatthias Springer   return success();
39556d68e8dSMatthias Springer }
39656d68e8dSMatthias Springer 
print(OpAsmPrinter & p)39756d68e8dSMatthias Springer void AllocTensorOp::print(OpAsmPrinter &p) {
39899260e95SMatthias Springer   p << "(" << getDynamicSizes() << ")";
39999260e95SMatthias Springer   if (getCopy())
40099260e95SMatthias Springer     p << " copy(" << getCopy() << ")";
40156d68e8dSMatthias Springer   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
40256d68e8dSMatthias Springer                               AllocTensorOp::getOperandSegmentSizeAttr()});
40356d68e8dSMatthias Springer   p << " : ";
40499260e95SMatthias Springer   auto type = getResult().getType();
40556d68e8dSMatthias Springer   if (auto validType = type.dyn_cast<::mlir::TensorType>())
40656d68e8dSMatthias Springer     p.printStrippedAttrOrType(validType);
40756d68e8dSMatthias Springer   else
40856d68e8dSMatthias Springer     p << type;
40956d68e8dSMatthias Springer }
41056d68e8dSMatthias Springer 
getDynamicSize(OpBuilder & b,unsigned idx)41156d68e8dSMatthias Springer Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
41256d68e8dSMatthias Springer   assert(isDynamicDim(idx) && "expected dynamic dim");
41399260e95SMatthias Springer   if (getCopy())
41499260e95SMatthias Springer     return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
41556d68e8dSMatthias Springer   return getOperand(getIndexOfDynamicSize(idx));
41656d68e8dSMatthias Springer }
41756d68e8dSMatthias Springer 
418ffdbecccSMatthias Springer //===----------------------------------------------------------------------===//
41957470abcSAlexander Belyaev // CloneOp
42057470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
42157470abcSAlexander Belyaev 
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)42257470abcSAlexander Belyaev void CloneOp::getEffects(
42357470abcSAlexander Belyaev     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
42457470abcSAlexander Belyaev         &effects) {
42599260e95SMatthias Springer   effects.emplace_back(MemoryEffects::Read::get(), getInput(),
42657470abcSAlexander Belyaev                        SideEffects::DefaultResource::get());
42799260e95SMatthias Springer   effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
42857470abcSAlexander Belyaev                        SideEffects::DefaultResource::get());
42999260e95SMatthias Springer   effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
43057470abcSAlexander Belyaev                        SideEffects::DefaultResource::get());
43157470abcSAlexander Belyaev }
43257470abcSAlexander Belyaev 
fold(ArrayRef<Attribute> operands)43357470abcSAlexander Belyaev OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
43457470abcSAlexander Belyaev   return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
43557470abcSAlexander Belyaev }
43657470abcSAlexander Belyaev 
43757470abcSAlexander Belyaev namespace {
43857470abcSAlexander Belyaev 
43957470abcSAlexander Belyaev /// Merge the clone and its source (by converting the clone to a cast) when
44057470abcSAlexander Belyaev /// possible.
44157470abcSAlexander Belyaev struct SimplifyClones : public OpRewritePattern<CloneOp> {
44257470abcSAlexander Belyaev   using OpRewritePattern<CloneOp>::OpRewritePattern;
44357470abcSAlexander Belyaev 
matchAndRewrite__anon56ded4390511::SimplifyClones44457470abcSAlexander Belyaev   LogicalResult matchAndRewrite(CloneOp cloneOp,
44557470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
44657470abcSAlexander Belyaev     if (cloneOp.use_empty()) {
44757470abcSAlexander Belyaev       rewriter.eraseOp(cloneOp);
44857470abcSAlexander Belyaev       return success();
44957470abcSAlexander Belyaev     }
45057470abcSAlexander Belyaev 
45199260e95SMatthias Springer     Value source = cloneOp.getInput();
45257470abcSAlexander Belyaev 
45357470abcSAlexander Belyaev     // This only finds dealloc operations for the immediate value. It should
45457470abcSAlexander Belyaev     // also consider aliases. That would also make the safety check below
45557470abcSAlexander Belyaev     // redundant.
45657470abcSAlexander Belyaev     llvm::Optional<Operation *> maybeCloneDeallocOp =
45799260e95SMatthias Springer         memref::findDealloc(cloneOp.getOutput());
45857470abcSAlexander Belyaev     // Skip if either of them has > 1 deallocate operations.
459491d2701SKazu Hirata     if (!maybeCloneDeallocOp.has_value())
46057470abcSAlexander Belyaev       return failure();
461af9f7d31SUday Bondhugula     llvm::Optional<Operation *> maybeSourceDeallocOp =
462af9f7d31SUday Bondhugula         memref::findDealloc(source);
463491d2701SKazu Hirata     if (!maybeSourceDeallocOp.has_value())
46457470abcSAlexander Belyaev       return failure();
46557470abcSAlexander Belyaev     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
46657470abcSAlexander Belyaev     Operation *sourceDeallocOp = *maybeSourceDeallocOp;
46757470abcSAlexander Belyaev 
46857470abcSAlexander Belyaev     // If both are deallocated in the same block, their in-block lifetimes
46957470abcSAlexander Belyaev     // might not fully overlap, so we cannot decide which one to drop.
47057470abcSAlexander Belyaev     if (cloneDeallocOp && sourceDeallocOp &&
47157470abcSAlexander Belyaev         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
47257470abcSAlexander Belyaev       return failure();
47357470abcSAlexander Belyaev 
47457470abcSAlexander Belyaev     Block *currentBlock = cloneOp->getBlock();
47557470abcSAlexander Belyaev     Operation *redundantDealloc = nullptr;
47657470abcSAlexander Belyaev     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
47757470abcSAlexander Belyaev       redundantDealloc = cloneDeallocOp;
47857470abcSAlexander Belyaev     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
47957470abcSAlexander Belyaev       redundantDealloc = sourceDeallocOp;
48057470abcSAlexander Belyaev     }
48157470abcSAlexander Belyaev 
48257470abcSAlexander Belyaev     if (!redundantDealloc)
48357470abcSAlexander Belyaev       return failure();
48457470abcSAlexander Belyaev 
48557470abcSAlexander Belyaev     // Safety check that there are no other deallocations inbetween
48657470abcSAlexander Belyaev     // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
48757470abcSAlexander Belyaev     // of source before the uses of the clone. With alias information, we could
48857470abcSAlexander Belyaev     // restrict this to only fail of the dealloc's operand is an alias
48957470abcSAlexander Belyaev     // of the source.
49057470abcSAlexander Belyaev     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
49157470abcSAlexander Belyaev          pos = pos->getNextNode()) {
49257470abcSAlexander Belyaev       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
49357470abcSAlexander Belyaev       if (!effectInterface)
49457470abcSAlexander Belyaev         continue;
49557470abcSAlexander Belyaev       if (effectInterface.hasEffect<MemoryEffects::Free>())
49657470abcSAlexander Belyaev         return failure();
49757470abcSAlexander Belyaev     }
49857470abcSAlexander Belyaev 
49957470abcSAlexander Belyaev     rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
50057470abcSAlexander Belyaev                                                 source);
50157470abcSAlexander Belyaev     rewriter.eraseOp(redundantDealloc);
50257470abcSAlexander Belyaev     return success();
50357470abcSAlexander Belyaev   }
50457470abcSAlexander Belyaev };
50557470abcSAlexander Belyaev 
506be0a7e9fSMehdi Amini } // namespace
50757470abcSAlexander Belyaev 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)5089f85c198SRiver Riddle void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
50957470abcSAlexander Belyaev                                           MLIRContext *context) {
510b4e0507cSTres Popp   results.add<SimplifyClones>(context);
51157470abcSAlexander Belyaev }
51257470abcSAlexander Belyaev 
51357470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
51427a431f5SMatthias Springer // DeallocTensorOp
51527a431f5SMatthias Springer //===----------------------------------------------------------------------===//
51627a431f5SMatthias Springer 
bufferize(RewriterBase & rewriter,const BufferizationOptions & options)51727a431f5SMatthias Springer LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
51827a431f5SMatthias Springer                                          const BufferizationOptions &options) {
51927a431f5SMatthias Springer   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
52027a431f5SMatthias Springer   if (failed(buffer))
52127a431f5SMatthias Springer     return failure();
52227a431f5SMatthias Springer   if (failed(options.createDealloc(rewriter, getLoc(), *buffer)))
52327a431f5SMatthias Springer     return failure();
52427a431f5SMatthias Springer   rewriter.eraseOp(getOperation());
52527a431f5SMatthias Springer   return success();
52627a431f5SMatthias Springer }
52727a431f5SMatthias Springer 
52827a431f5SMatthias Springer //===----------------------------------------------------------------------===//
52957470abcSAlexander Belyaev // ToTensorOp
53057470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
53157470abcSAlexander Belyaev 
fold(ArrayRef<Attribute>)53257470abcSAlexander Belyaev OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
53399260e95SMatthias Springer   if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
53457470abcSAlexander Belyaev     // Approximate alias analysis by conservatively folding only when no there
53557470abcSAlexander Belyaev     // is no interleaved operation.
53657470abcSAlexander Belyaev     if (toMemref->getBlock() == this->getOperation()->getBlock() &&
53757470abcSAlexander Belyaev         toMemref->getNextNode() == this->getOperation())
53899260e95SMatthias Springer       return toMemref.getTensor();
53957470abcSAlexander Belyaev   return {};
54057470abcSAlexander Belyaev }
54157470abcSAlexander Belyaev 
54257470abcSAlexander Belyaev namespace {
54357470abcSAlexander Belyaev struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
54457470abcSAlexander Belyaev   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
54557470abcSAlexander Belyaev 
matchAndRewrite__anon56ded4390611::DimOfToTensorFolder54657470abcSAlexander Belyaev   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
54757470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
54804235d07SJacques Pienaar     auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
54957470abcSAlexander Belyaev     if (!memrefToTensorOp)
55057470abcSAlexander Belyaev       return failure();
55157470abcSAlexander Belyaev 
55299260e95SMatthias Springer     rewriter.replaceOpWithNewOp<memref::DimOp>(
55304235d07SJacques Pienaar         dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
55457470abcSAlexander Belyaev     return success();
55557470abcSAlexander Belyaev   }
55657470abcSAlexander Belyaev };
55757470abcSAlexander Belyaev } // namespace
55857470abcSAlexander Belyaev 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)55957470abcSAlexander Belyaev void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
56057470abcSAlexander Belyaev                                              MLIRContext *context) {
561fc9b37ddSMatthias Springer   results.add<DimOfToTensorFolder>(context);
56257470abcSAlexander Belyaev }
56357470abcSAlexander Belyaev 
56457470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
56557470abcSAlexander Belyaev // ToMemrefOp
56657470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
56757470abcSAlexander Belyaev 
fold(ArrayRef<Attribute>)56857470abcSAlexander Belyaev OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
56999260e95SMatthias Springer   if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
57099260e95SMatthias Springer     if (memrefToTensor.getMemref().getType() == getType())
57199260e95SMatthias Springer       return memrefToTensor.getMemref();
57257470abcSAlexander Belyaev   return {};
57357470abcSAlexander Belyaev }
57457470abcSAlexander Belyaev 
57557470abcSAlexander Belyaev namespace {
57657470abcSAlexander Belyaev 
57757470abcSAlexander Belyaev /// Replace tensor.cast + to_memref by to_memref + memref.cast.
57857470abcSAlexander Belyaev struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
57957470abcSAlexander Belyaev   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
58057470abcSAlexander Belyaev 
matchAndRewrite__anon56ded4390711::ToMemrefOfCast58157470abcSAlexander Belyaev   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
58257470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const final {
58357470abcSAlexander Belyaev     auto tensorCastOperand =
58457470abcSAlexander Belyaev         toMemref.getOperand().getDefiningOp<tensor::CastOp>();
58557470abcSAlexander Belyaev     if (!tensorCastOperand)
58657470abcSAlexander Belyaev       return failure();
58757470abcSAlexander Belyaev     auto srcTensorType =
58857470abcSAlexander Belyaev         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
58957470abcSAlexander Belyaev     if (!srcTensorType)
59057470abcSAlexander Belyaev       return failure();
59157470abcSAlexander Belyaev     auto memrefType = MemRefType::get(srcTensorType.getShape(),
59257470abcSAlexander Belyaev                                       srcTensorType.getElementType());
59357470abcSAlexander Belyaev     Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
59457470abcSAlexander Belyaev                                                tensorCastOperand.getOperand());
59557470abcSAlexander Belyaev     rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
59657470abcSAlexander Belyaev                                                 memref);
59757470abcSAlexander Belyaev     return success();
59857470abcSAlexander Belyaev   }
59957470abcSAlexander Belyaev };
60057470abcSAlexander Belyaev 
601cb471241SMatthias Springer /// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
602cb471241SMatthias Springer /// cast if necessary.
603cb471241SMatthias Springer struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
604b00ee46bSMatthias Springer   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
605b00ee46bSMatthias Springer 
matchAndRewrite__anon56ded4390711::ToMemrefToTensorFolding606b00ee46bSMatthias Springer   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
607b00ee46bSMatthias Springer                                 PatternRewriter &rewriter) const final {
608cb471241SMatthias Springer     return foldToMemrefToTensorPair(rewriter, toMemref);
609b00ee46bSMatthias Springer   }
61057470abcSAlexander Belyaev };
61157470abcSAlexander Belyaev 
61257470abcSAlexander Belyaev /// Fold a load on a to_memref operation into an tensor.extract on the
61357470abcSAlexander Belyaev /// corresponding tensor.
61457470abcSAlexander Belyaev struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
61557470abcSAlexander Belyaev   using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
61657470abcSAlexander Belyaev 
matchAndRewrite__anon56ded4390711::LoadOfToMemref61757470abcSAlexander Belyaev   LogicalResult matchAndRewrite(memref::LoadOp load,
61857470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
619136d746eSJacques Pienaar     auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
62057470abcSAlexander Belyaev     if (!toMemref)
62157470abcSAlexander Belyaev       return failure();
62257470abcSAlexander Belyaev 
62399260e95SMatthias Springer     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
624136d746eSJacques Pienaar                                                    load.getIndices());
62557470abcSAlexander Belyaev     return success();
62657470abcSAlexander Belyaev   }
62757470abcSAlexander Belyaev };
62857470abcSAlexander Belyaev 
62957470abcSAlexander Belyaev /// Fold dim of a to_memref into the dim of the tensor.
63057470abcSAlexander Belyaev struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
63157470abcSAlexander Belyaev   using OpRewritePattern<memref::DimOp>::OpRewritePattern;
63257470abcSAlexander Belyaev 
matchAndRewrite__anon56ded4390711::DimOfCastOp63357470abcSAlexander Belyaev   LogicalResult matchAndRewrite(memref::DimOp dimOp,
63457470abcSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
635136d746eSJacques Pienaar     auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
63657470abcSAlexander Belyaev     if (!castOp)
63757470abcSAlexander Belyaev       return failure();
63857470abcSAlexander Belyaev     Value newSource = castOp.getOperand();
639136d746eSJacques Pienaar     rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
640136d746eSJacques Pienaar                                                dimOp.getIndex());
64157470abcSAlexander Belyaev     return success();
64257470abcSAlexander Belyaev   }
64357470abcSAlexander Belyaev };
64457470abcSAlexander Belyaev 
64557470abcSAlexander Belyaev } // namespace
64657470abcSAlexander Belyaev 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)64757470abcSAlexander Belyaev void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
64857470abcSAlexander Belyaev                                              MLIRContext *context) {
649cb471241SMatthias Springer   results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
650cb471241SMatthias Springer               ToMemrefToTensorFolding>(context);
65157470abcSAlexander Belyaev }
65257470abcSAlexander Belyaev 
bufferize(RewriterBase & rewriter,const BufferizationOptions & options)653b00ee46bSMatthias Springer LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
654b55d55ecSMatthias Springer                                     const BufferizationOptions &options) {
655b00ee46bSMatthias Springer   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
6560b293bf0SMatthias Springer   (void)foldToMemrefToTensorPair(rewriter, *this);
6570b293bf0SMatthias Springer   // Note: The return value of `bufferize` indicates whether there was an error
6580b293bf0SMatthias Springer   // or not. (And not whether the pattern matched or not.)
6590b293bf0SMatthias Springer   return success();
660b00ee46bSMatthias Springer }
661b00ee46bSMatthias Springer 
buildDealloc(OpBuilder & builder,Value alloc)66257470abcSAlexander Belyaev Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
66357470abcSAlexander Belyaev   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
66457470abcSAlexander Belyaev       .getOperation();
66557470abcSAlexander Belyaev }
66657470abcSAlexander Belyaev 
buildClone(OpBuilder & builder,Value alloc)66757470abcSAlexander Belyaev Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
66857470abcSAlexander Belyaev   return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
66957470abcSAlexander Belyaev }
67057470abcSAlexander Belyaev 
67157470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
67257470abcSAlexander Belyaev // TableGen'd op method definitions
67357470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
67457470abcSAlexander Belyaev 
67557470abcSAlexander Belyaev #define GET_OP_CLASSES
67657470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
677