//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::shape;

namespace mlir {
namespace shape {
namespace {

/// Bufferization of shape.assuming.
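///
/// The net effect is, roughly (illustrative types; the exact memref layouts
/// depend on the bufferization options):
///
///   %0 = shape.assuming %w -> (tensor<?xf32>) {
///     ...
///     shape.assuming_yield %t : tensor<?xf32>
///   }
///
/// becomes an assuming op that yields the corresponding buffer, followed by a
/// bufferization.to_tensor op that recreates a tensor SSA value for the
/// remaining tensor uses:
///
///   %0 = shape.assuming %w -> (memref<?xf32>) {
///     ...
///     shape.assuming_yield %m : memref<?xf32>
///   }
///   %1 = bufferization.to_tensor %0 : memref<?xf32>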
struct AssumingOpInterface
    : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
                                                    shape::AssumingOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // AssumingOps do not have tensor OpOperands. The yielded value can be any
    // SSA value that is in scope. To allow for use-def chain traversal through
    // AssumingOps in the analysis, the corresponding yield value is considered
    // to be aliasing with the result.
    auto assumingOp = cast<shape::AssumingOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
        assumingOp.getDoRegion().front().getTerminator());
    assert(yieldOp && "expected shape.assuming_yield terminator");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto assumingOp = cast<shape::AssumingOp>(op);
    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto yieldOp = cast<shape::AssumingYieldOp>(
        assumingOp.getDoRegion().front().getTerminator());

    // Create new op and move over region.
    TypeRange newResultTypes(yieldOp.operands());
    auto newOp = rewriter.create<shape::AssumingOp>(
        op->getLoc(), newResultTypes, assumingOp.getWitness());
    newOp.getDoRegion().takeBody(assumingOp.getRegion());

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            assumingOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(assumingOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};
/// Bufferization of shape.assuming_yield. Bufferized as part of its enclosing
/// op, so this is for analysis only.
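///
/// The mapping between yield operands and the enclosing op's results is
/// positional: operand i of shape.assuming_yield aliases result i of the
/// parent shape.assuming (see `getAliasingOpResult` below).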
struct AssumingYieldOpInterface
    : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
                                                    shape::AssumingYieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(isa<shape::AssumingOp>(op->getParentOp()) &&
           "expected that parent is an AssumingOp");
    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<shape::AssumingYieldOp>(op);
    SmallVector<Value> newResults;
    for (Value value : yieldOp.operands()) {
      if (value.getType().isa<TensorType>()) {
        FailureOr<Value> buffer = getBuffer(rewriter, value, options);
        if (failed(buffer))
          return failure();
        newResults.push_back(*buffer);
      } else {
        newResults.push_back(value);
      }
    }
    replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
                                                         newResults);
    return success();
  }
};

} // namespace
} // namespace shape
} // namespace mlir
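
// Typical usage (illustrative): register these external models on the
// DialectRegistry before running bufferization, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<shape::ShapeDialect>();
//   shape::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);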
void mlir::shape::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
    shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
    shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
  });
}