157470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
257470abcSAlexander Belyaev //
357470abcSAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
457470abcSAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
557470abcSAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
657470abcSAlexander Belyaev //
757470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
857470abcSAlexander Belyaev
9eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
129a3d60e0SAart Bik #include "mlir/Dialect/Func/IR/FuncOps.h"
13eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
1457470abcSAlexander Belyaev #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
159a3d60e0SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
16eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h"
17ec55f0bdSMatthias Springer #include "mlir/IR/Matchers.h"
1857470abcSAlexander Belyaev
1957470abcSAlexander Belyaev using namespace mlir;
2057470abcSAlexander Belyaev using namespace mlir::bufferization;
2157470abcSAlexander Belyaev
2257470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
23fa7c8cb4SMatthias Springer // Helper functions
24fa7c8cb4SMatthias Springer //===----------------------------------------------------------------------===//
25fa7c8cb4SMatthias Springer
/// Try to cast the ranked memref `value` to the ranked memref type `destType`.
/// If a plain `memref.cast` cannot be statically guaranteed to succeed at
/// runtime, fall back to allocating a buffer of `destType` and copying.
///
/// Returns the casted (or newly allocated + copied) value, or failure if the
/// element type, memory space or rank of the two types do not match.
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    // A dynamic source offset/stride cast to a static one may fail at
    // runtime; the reverse direction is always safe.
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  // Fallback: allocate a fresh buffer of the destination type. Gather the
  // dynamic extents of `destType` by querying the source value's dims.
  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}
84fa7c8cb4SMatthias Springer
85d820acddSMatthias Springer /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
86d820acddSMatthias Springer /// to_memref op are different, a memref.cast is needed.
LogicalResult
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref) {
  // Only fold if the to_memref operand is directly produced by a to_tensor.
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    // May insert a realloc + copy when a plain cast is not guaranteed to
    // succeed at runtime (e.g. dynamic-to-static stride/offset).
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}
131d820acddSMatthias Springer
populateDynamicDimSizes(OpBuilder & b,Location loc,Value shapedValue,SmallVector<Value> & dynamicDims)132b3ebe3beSMatthias Springer void mlir::bufferization::populateDynamicDimSizes(
133b3ebe3beSMatthias Springer OpBuilder &b, Location loc, Value shapedValue,
134b3ebe3beSMatthias Springer SmallVector<Value> &dynamicDims) {
135b3ebe3beSMatthias Springer auto shapedType = shapedValue.getType().cast<ShapedType>();
136b3ebe3beSMatthias Springer for (int64_t i = 0; i < shapedType.getRank(); ++i) {
137b3ebe3beSMatthias Springer if (shapedType.isDynamicDim(i)) {
138b3ebe3beSMatthias Springer if (shapedType.isa<MemRefType>()) {
139b3ebe3beSMatthias Springer dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
140b3ebe3beSMatthias Springer } else {
141b3ebe3beSMatthias Springer assert(shapedType.isa<RankedTensorType>() && "expected tensor");
142b3ebe3beSMatthias Springer dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
143b3ebe3beSMatthias Springer }
144b3ebe3beSMatthias Springer }
145b3ebe3beSMatthias Springer }
146b3ebe3beSMatthias Springer }
147b3ebe3beSMatthias Springer
148fa7c8cb4SMatthias Springer //===----------------------------------------------------------------------===//
149ffdbecccSMatthias Springer // AllocTensorOp
150ffdbecccSMatthias Springer //===----------------------------------------------------------------------===//
151ffdbecccSMatthias Springer
/// Bufferize `alloc_tensor` into a new buffer allocation (via the allocation
/// callback in `options`), optionally preceded by bufferizing the `copy`
/// operand and copying it into the new buffer, and optionally followed by a
/// dealloc of the buffer at the end of the block.
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Compute memory space of this allocation. Precedence: the op's explicit
  // `memory_space` attribute, then the space of the `copy` buffer, then the
  // default memory space from the bufferization options.
  unsigned memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    memorySpace =
        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
  } else if (options.defaultMemorySpace.has_value()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return op->emitError("could not infer memory space");
  }

  // Create memory allocation. The buffer has an identity layout (empty
  // AffineMap).
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType(),
                      AffineMap(), memorySpace);
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    // With a `copy` operand, dynamic extents come from the copied buffer
    // (the verifier rejects ops that have both).
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  bool dealloc =
      shouldDeallocateOpResult(getResult().cast<OpResult>(), options);

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  // Deallocate right before the terminator of the current block.
  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}
222ffdbecccSMatthias Springer
isMemoryWrite(OpResult opResult,const AnalysisState & state)22356d68e8dSMatthias Springer bool AllocTensorOp::isMemoryWrite(OpResult opResult,
22456d68e8dSMatthias Springer const AnalysisState &state) {
22556d68e8dSMatthias Springer // AllocTensorOps do not write unless they have a `copy` value.
22699260e95SMatthias Springer return static_cast<bool>(getCopy());
22756d68e8dSMatthias Springer }
22856d68e8dSMatthias Springer
/// The only tensor operand is `copy` (always the last operand); it is read
/// when the allocation is initialized from it.
bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}
23556d68e8dSMatthias Springer
/// The `copy` operand is never written to; the write goes into the newly
/// allocated buffer instead.
bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}
24256d68e8dSMatthias Springer
24356d68e8dSMatthias Springer SmallVector<OpResult>
getAliasingOpResult(OpOperand & opOperand,const AnalysisState & state)24456d68e8dSMatthias Springer AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
24556d68e8dSMatthias Springer const AnalysisState &state) {
24656d68e8dSMatthias Springer // This is a new allocation. It does not alias with any other buffer.
24756d68e8dSMatthias Springer return {};
24856d68e8dSMatthias Springer }
24956d68e8dSMatthias Springer
/// Verify structural invariants of `alloc_tensor`:
///  - `dynamic_sizes` and `copy` are mutually exclusive;
///  - without `copy`, the number of dynamic size operands matches the number
///    of dynamic dims in the result type;
///  - with `copy`, the copied tensor type equals the result type;
///  - sparse allocations must not escape the function directly.
LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}
271ffdbecccSMatthias Springer
/// Build an `alloc_tensor` op with the given dynamic sizes and neither a
/// `copy` operand nor an explicit memory space.
void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*memory_space=*/IntegerAttr());
}
277c06f01ffSMatthias Springer
/// Build an `alloc_tensor` op that copies the given tensor, without an
/// explicit memory space.
void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy,
        /*memory_space=*/IntegerAttr());
}
28456d68e8dSMatthias Springer
285ffdbecccSMatthias Springer namespace {
286ffdbecccSMatthias Springer /// Change the type of the result of a `bufferization.alloc_tensor` by making
287ffdbecccSMatthias Springer /// the result type statically sized along dimension that in the original
288ffdbecccSMatthias Springer /// operation where defined as dynamic, but the size was defined using a
289ffdbecccSMatthias Springer /// `constant` op. For example:
290ffdbecccSMatthias Springer ///
291ffdbecccSMatthias Springer /// %c5 = arith.constant 5: index
292ec55f0bdSMatthias Springer /// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
293ffdbecccSMatthias Springer ///
294ffdbecccSMatthias Springer /// to
295ffdbecccSMatthias Springer ///
296ec55f0bdSMatthias Springer /// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
297ffdbecccSMatthias Springer struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
298ffdbecccSMatthias Springer using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
299ffdbecccSMatthias Springer
matchAndRewrite__anon56ded4390311::ReplaceStaticShapeDims300ffdbecccSMatthias Springer LogicalResult matchAndRewrite(AllocTensorOp op,
301ffdbecccSMatthias Springer PatternRewriter &rewriter) const override {
30299260e95SMatthias Springer if (op.getCopy())
30356d68e8dSMatthias Springer return failure();
304ec55f0bdSMatthias Springer SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
305ec55f0bdSMatthias Springer SmallVector<Value> newDynamicSizes;
306ec55f0bdSMatthias Springer unsigned int dynValCounter = 0;
307ec55f0bdSMatthias Springer for (int64_t i = 0; i < op.getType().getRank(); ++i) {
308ec55f0bdSMatthias Springer if (!op.isDynamicDim(i))
309ffdbecccSMatthias Springer continue;
31099260e95SMatthias Springer Value value = op.getDynamicSizes()[dynValCounter++];
311ec55f0bdSMatthias Springer APInt intVal;
312ec55f0bdSMatthias Springer if (matchPattern(value, m_ConstantInt(&intVal))) {
313ec55f0bdSMatthias Springer newShape[i] = intVal.getSExtValue();
314ec55f0bdSMatthias Springer } else {
315ec55f0bdSMatthias Springer newDynamicSizes.push_back(value);
316ffdbecccSMatthias Springer }
317ffdbecccSMatthias Springer }
318ec55f0bdSMatthias Springer RankedTensorType newType = RankedTensorType::get(
319ec55f0bdSMatthias Springer newShape, op.getType().getElementType(), op.getType().getEncoding());
320ffdbecccSMatthias Springer if (newType == op.getType())
321ffdbecccSMatthias Springer return failure();
32256d68e8dSMatthias Springer auto newOp = rewriter.create<AllocTensorOp>(
3233474d10eSMatthias Springer op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
324ffdbecccSMatthias Springer rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
325ffdbecccSMatthias Springer return success();
326ffdbecccSMatthias Springer }
327ffdbecccSMatthias Springer };
328ffdbecccSMatthias Springer
329ffdbecccSMatthias Springer struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
330ffdbecccSMatthias Springer using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
331ffdbecccSMatthias Springer
matchAndRewrite__anon56ded4390311::FoldDimOfAllocTensorOp332ffdbecccSMatthias Springer LogicalResult matchAndRewrite(tensor::DimOp dimOp,
333ffdbecccSMatthias Springer PatternRewriter &rewriter) const override {
334ffdbecccSMatthias Springer Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
33504235d07SJacques Pienaar auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
336ffdbecccSMatthias Springer if (!allocTensorOp || !maybeConstantIndex)
337ffdbecccSMatthias Springer return failure();
338ec55f0bdSMatthias Springer if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
339ffdbecccSMatthias Springer return failure();
34056d68e8dSMatthias Springer rewriter.replaceOp(
34156d68e8dSMatthias Springer dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
342ffdbecccSMatthias Springer return success();
343ffdbecccSMatthias Springer }
344ffdbecccSMatthias Springer };
345ffdbecccSMatthias Springer } // namespace
346ffdbecccSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * ctx)347ffdbecccSMatthias Springer void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
348ffdbecccSMatthias Springer MLIRContext *ctx) {
349ffdbecccSMatthias Springer results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
350ffdbecccSMatthias Springer }
351ffdbecccSMatthias Springer
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)352ffdbecccSMatthias Springer LogicalResult AllocTensorOp::reifyResultShapes(
353ffdbecccSMatthias Springer OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
354ffdbecccSMatthias Springer auto shapes = llvm::to_vector<4>(llvm::map_range(
355ffdbecccSMatthias Springer llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
356ec55f0bdSMatthias Springer if (isDynamicDim(dim))
35756d68e8dSMatthias Springer return getDynamicSize(builder, dim);
358ffdbecccSMatthias Springer return builder.create<arith::ConstantIndexOp>(getLoc(),
359ffdbecccSMatthias Springer getStaticSize(dim));
360ffdbecccSMatthias Springer }));
361ffdbecccSMatthias Springer reifiedReturnShapes.emplace_back(std::move(shapes));
362ffdbecccSMatthias Springer return success();
363ffdbecccSMatthias Springer }
364ffdbecccSMatthias Springer
/// Parse the custom assembly format:
///   `(` dynamic-size-operands `)` (`copy` `(` tensor-operand `)`)?
///   attr-dict `:` tensor-type
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  // The `copy` clause is optional.
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  // Dynamic sizes are of index type; the `copy` operand (if present) has the
  // same type as the result tensor.
  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  // Record the sizes of the two variadic operand groups (dynamic sizes and
  // the optional copy operand) for AttrSizedOperandSegments.
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}
39656d68e8dSMatthias Springer
print(OpAsmPrinter & p)39756d68e8dSMatthias Springer void AllocTensorOp::print(OpAsmPrinter &p) {
39899260e95SMatthias Springer p << "(" << getDynamicSizes() << ")";
39999260e95SMatthias Springer if (getCopy())
40099260e95SMatthias Springer p << " copy(" << getCopy() << ")";
40156d68e8dSMatthias Springer p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
40256d68e8dSMatthias Springer AllocTensorOp::getOperandSegmentSizeAttr()});
40356d68e8dSMatthias Springer p << " : ";
40499260e95SMatthias Springer auto type = getResult().getType();
40556d68e8dSMatthias Springer if (auto validType = type.dyn_cast<::mlir::TensorType>())
40656d68e8dSMatthias Springer p.printStrippedAttrOrType(validType);
40756d68e8dSMatthias Springer else
40856d68e8dSMatthias Springer p << type;
40956d68e8dSMatthias Springer }
41056d68e8dSMatthias Springer
getDynamicSize(OpBuilder & b,unsigned idx)41156d68e8dSMatthias Springer Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
41256d68e8dSMatthias Springer assert(isDynamicDim(idx) && "expected dynamic dim");
41399260e95SMatthias Springer if (getCopy())
41499260e95SMatthias Springer return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
41556d68e8dSMatthias Springer return getOperand(getIndexOfDynamicSize(idx));
41656d68e8dSMatthias Springer }
41756d68e8dSMatthias Springer
418ffdbecccSMatthias Springer //===----------------------------------------------------------------------===//
41957470abcSAlexander Belyaev // CloneOp
42057470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
42157470abcSAlexander Belyaev
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)42257470abcSAlexander Belyaev void CloneOp::getEffects(
42357470abcSAlexander Belyaev SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
42457470abcSAlexander Belyaev &effects) {
42599260e95SMatthias Springer effects.emplace_back(MemoryEffects::Read::get(), getInput(),
42657470abcSAlexander Belyaev SideEffects::DefaultResource::get());
42799260e95SMatthias Springer effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
42857470abcSAlexander Belyaev SideEffects::DefaultResource::get());
42999260e95SMatthias Springer effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
43057470abcSAlexander Belyaev SideEffects::DefaultResource::get());
43157470abcSAlexander Belyaev }
43257470abcSAlexander Belyaev
fold(ArrayRef<Attribute> operands)43357470abcSAlexander Belyaev OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
43457470abcSAlexander Belyaev return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
43557470abcSAlexander Belyaev }
43657470abcSAlexander Belyaev
43757470abcSAlexander Belyaev namespace {
43857470abcSAlexander Belyaev
43957470abcSAlexander Belyaev /// Merge the clone and its source (by converting the clone to a cast) when
44057470abcSAlexander Belyaev /// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    // A clone without uses is trivially dead.
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    // Note: either dealloc may be null (no dealloc found for that value).
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    // Pick the dealloc that lives in the clone's own block; merging clone and
    // source makes that dealloc redundant.
    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations inbetween
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail of the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    // Replace the clone with a cast of its source and drop the now-redundant
    // dealloc.
    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};
50557470abcSAlexander Belyaev
506be0a7e9fSMehdi Amini } // namespace
50757470abcSAlexander Belyaev
/// Populate canonicalization patterns for bufferization.clone.
void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // SimplifyClones turns a clone whose source or result is redundantly
  // deallocated into a memref.cast of the source.
  results.add<SimplifyClones>(context);
}
51257470abcSAlexander Belyaev
51357470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
51427a431f5SMatthias Springer // DeallocTensorOp
51527a431f5SMatthias Springer //===----------------------------------------------------------------------===//
51627a431f5SMatthias Springer
bufferize(RewriterBase & rewriter,const BufferizationOptions & options)51727a431f5SMatthias Springer LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
51827a431f5SMatthias Springer const BufferizationOptions &options) {
51927a431f5SMatthias Springer FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
52027a431f5SMatthias Springer if (failed(buffer))
52127a431f5SMatthias Springer return failure();
52227a431f5SMatthias Springer if (failed(options.createDealloc(rewriter, getLoc(), *buffer)))
52327a431f5SMatthias Springer return failure();
52427a431f5SMatthias Springer rewriter.eraseOp(getOperation());
52527a431f5SMatthias Springer return success();
52627a431f5SMatthias Springer }
52727a431f5SMatthias Springer
52827a431f5SMatthias Springer //===----------------------------------------------------------------------===//
52957470abcSAlexander Belyaev // ToTensorOp
53057470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
53157470abcSAlexander Belyaev
fold(ArrayRef<Attribute>)53257470abcSAlexander Belyaev OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
53399260e95SMatthias Springer if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
53457470abcSAlexander Belyaev // Approximate alias analysis by conservatively folding only when no there
53557470abcSAlexander Belyaev // is no interleaved operation.
53657470abcSAlexander Belyaev if (toMemref->getBlock() == this->getOperation()->getBlock() &&
53757470abcSAlexander Belyaev toMemref->getNextNode() == this->getOperation())
53899260e95SMatthias Springer return toMemref.getTensor();
53957470abcSAlexander Belyaev return {};
54057470abcSAlexander Belyaev }
54157470abcSAlexander Belyaev
54257470abcSAlexander Belyaev namespace {
54357470abcSAlexander Belyaev struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
54457470abcSAlexander Belyaev using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
54557470abcSAlexander Belyaev
matchAndRewrite__anon56ded4390611::DimOfToTensorFolder54657470abcSAlexander Belyaev LogicalResult matchAndRewrite(tensor::DimOp dimOp,
54757470abcSAlexander Belyaev PatternRewriter &rewriter) const override {
54804235d07SJacques Pienaar auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
54957470abcSAlexander Belyaev if (!memrefToTensorOp)
55057470abcSAlexander Belyaev return failure();
55157470abcSAlexander Belyaev
55299260e95SMatthias Springer rewriter.replaceOpWithNewOp<memref::DimOp>(
55304235d07SJacques Pienaar dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
55457470abcSAlexander Belyaev return success();
55557470abcSAlexander Belyaev }
55657470abcSAlexander Belyaev };
55757470abcSAlexander Belyaev } // namespace
55857470abcSAlexander Belyaev
/// Populate canonicalization patterns for bufferization.to_tensor.
void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Rewrites tensor.dim(to_tensor(m)) into memref.dim(m).
  results.add<DimOfToTensorFolder>(context);
}
56357470abcSAlexander Belyaev
56457470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
56557470abcSAlexander Belyaev // ToMemrefOp
56657470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
56757470abcSAlexander Belyaev
fold(ArrayRef<Attribute>)56857470abcSAlexander Belyaev OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
56999260e95SMatthias Springer if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
57099260e95SMatthias Springer if (memrefToTensor.getMemref().getType() == getType())
57199260e95SMatthias Springer return memrefToTensor.getMemref();
57257470abcSAlexander Belyaev return {};
57357470abcSAlexander Belyaev }
57457470abcSAlexander Belyaev
57557470abcSAlexander Belyaev namespace {
57657470abcSAlexander Belyaev
57757470abcSAlexander Belyaev /// Replace tensor.cast + to_memref by to_memref + memref.cast.
57857470abcSAlexander Belyaev struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
57957470abcSAlexander Belyaev using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
58057470abcSAlexander Belyaev
matchAndRewrite__anon56ded4390711::ToMemrefOfCast58157470abcSAlexander Belyaev LogicalResult matchAndRewrite(ToMemrefOp toMemref,
58257470abcSAlexander Belyaev PatternRewriter &rewriter) const final {
58357470abcSAlexander Belyaev auto tensorCastOperand =
58457470abcSAlexander Belyaev toMemref.getOperand().getDefiningOp<tensor::CastOp>();
58557470abcSAlexander Belyaev if (!tensorCastOperand)
58657470abcSAlexander Belyaev return failure();
58757470abcSAlexander Belyaev auto srcTensorType =
58857470abcSAlexander Belyaev tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
58957470abcSAlexander Belyaev if (!srcTensorType)
59057470abcSAlexander Belyaev return failure();
59157470abcSAlexander Belyaev auto memrefType = MemRefType::get(srcTensorType.getShape(),
59257470abcSAlexander Belyaev srcTensorType.getElementType());
59357470abcSAlexander Belyaev Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
59457470abcSAlexander Belyaev tensorCastOperand.getOperand());
59557470abcSAlexander Belyaev rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
59657470abcSAlexander Belyaev memref);
59757470abcSAlexander Belyaev return success();
59857470abcSAlexander Belyaev }
59957470abcSAlexander Belyaev };
60057470abcSAlexander Belyaev
/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // All of the matching/rewriting logic lives in the shared helper, which
    // is also used by ToMemrefOp::bufferize.
    return foldToMemrefToTensorPair(rewriter, toMemref);
  }
};
61157470abcSAlexander Belyaev
61257470abcSAlexander Belyaev /// Fold a load on a to_memref operation into an tensor.extract on the
61357470abcSAlexander Belyaev /// corresponding tensor.
61457470abcSAlexander Belyaev struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
61557470abcSAlexander Belyaev using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
61657470abcSAlexander Belyaev
matchAndRewrite__anon56ded4390711::LoadOfToMemref61757470abcSAlexander Belyaev LogicalResult matchAndRewrite(memref::LoadOp load,
61857470abcSAlexander Belyaev PatternRewriter &rewriter) const override {
619136d746eSJacques Pienaar auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
62057470abcSAlexander Belyaev if (!toMemref)
62157470abcSAlexander Belyaev return failure();
62257470abcSAlexander Belyaev
62399260e95SMatthias Springer rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
624136d746eSJacques Pienaar load.getIndices());
62557470abcSAlexander Belyaev return success();
62657470abcSAlexander Belyaev }
62757470abcSAlexander Belyaev };
62857470abcSAlexander Belyaev
62957470abcSAlexander Belyaev /// Fold dim of a to_memref into the dim of the tensor.
63057470abcSAlexander Belyaev struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
63157470abcSAlexander Belyaev using OpRewritePattern<memref::DimOp>::OpRewritePattern;
63257470abcSAlexander Belyaev
matchAndRewrite__anon56ded4390711::DimOfCastOp63357470abcSAlexander Belyaev LogicalResult matchAndRewrite(memref::DimOp dimOp,
63457470abcSAlexander Belyaev PatternRewriter &rewriter) const override {
635136d746eSJacques Pienaar auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
63657470abcSAlexander Belyaev if (!castOp)
63757470abcSAlexander Belyaev return failure();
63857470abcSAlexander Belyaev Value newSource = castOp.getOperand();
639136d746eSJacques Pienaar rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
640136d746eSJacques Pienaar dimOp.getIndex());
64157470abcSAlexander Belyaev return success();
64257470abcSAlexander Belyaev }
64357470abcSAlexander Belyaev };
64457470abcSAlexander Belyaev
64557470abcSAlexander Belyaev } // namespace
64657470abcSAlexander Belyaev
/// Populate canonicalization patterns for bufferization.to_memref.
void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // DimOfCastOp: memref.dim(to_memref(t)) -> tensor.dim(t).
  // LoadOfToMemref: memref.load(to_memref(t)) -> tensor.extract(t).
  // ToMemrefOfCast: to_memref(tensor.cast(t)) -> memref.cast(to_memref(t)).
  // ToMemrefToTensorFolding: to_memref(to_tensor(m)) -> m (+ cast).
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}
65257470abcSAlexander Belyaev
/// Bufferization of to_memref is a best-effort folding of
/// to_memref(to_tensor(x)); the op is otherwise already in memref form.
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. The
  // result is deliberately ignored: not matching is not an error.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}
661b00ee46bSMatthias Springer
buildDealloc(OpBuilder & builder,Value alloc)66257470abcSAlexander Belyaev Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
66357470abcSAlexander Belyaev return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
66457470abcSAlexander Belyaev .getOperation();
66557470abcSAlexander Belyaev }
66657470abcSAlexander Belyaev
buildClone(OpBuilder & builder,Value alloc)66757470abcSAlexander Belyaev Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
66857470abcSAlexander Belyaev return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
66957470abcSAlexander Belyaev }
67057470abcSAlexander Belyaev
67157470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
67257470abcSAlexander Belyaev // TableGen'd op method definitions
67357470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
67457470abcSAlexander Belyaev
67557470abcSAlexander Belyaev #define GET_OP_CLASSES
67657470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
677