//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::shape;

namespace mlir {
namespace shape {
namespace {

/// Bufferization of shape.assuming.
struct AssumingOpInterface
    : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
                                                    shape::AssumingOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // AssumingOps do not have tensor OpOperands. The yielded value can be any
    // SSA value that is in scope. To allow for use-def chain traversal through
    // AssumingOps in the analysis, the corresponding yield value is considered
    // to be aliasing with the result.
    auto assumingOp = cast<shape::AssumingOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
        assumingOp.getDoRegion().front().getTerminator());
    assert(yieldOp && "expected shape.assuming_yield terminator");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

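  /// Bufferize shape.assuming by recreating the op with memref result types,
  /// moving over the region, and bridging the terminator and the result uses
  /// with casts. A sketch of the rewrite, assuming a single illustrative
  /// tensor<?xf32> result (the exact memref layout depends on the
  /// bufferization options):
  ///
  ///   %0 = shape.assuming %w -> (tensor<?xf32>) {
  ///     shape.assuming_yield %t : tensor<?xf32>
  ///   }
  ///
  /// is rewritten to
  ///
  ///   %0 = shape.assuming %w -> (memref<?xf32>) {
  ///     %m = bufferization.to_memref %t : memref<?xf32>
  ///     shape.assuming_yield %m : memref<?xf32>
  ///   }
  ///   %1 = bufferization.to_tensor %0 : memref<?xf32>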
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto assumingOp = cast<shape::AssumingOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : assumingOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp = rewriter.create<shape::AssumingOp>(
        op->getLoc(), newResultTypes, assumingOp.getWitness());
    newOp.getDoRegion().takeBody(assumingOp.getRegion());

    // Update terminator.
    assert(newOp.getDoRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getDoRegion().front();
    auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
                                                        newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            assumingOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(assumingOp, newResults);

    return success();
  }

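  /// The buffer yielded from the region is used for the corresponding op
  /// result without a copy, so result and yielded buffer are equivalent.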
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of shape.assuming_yield. Bufferized as part of its enclosing
/// op, so this is for analysis only.
struct AssumingYieldOpInterface
    : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
                                                    shape::AssumingYieldOp> {
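  // Yield operands are conservatively treated as read: the yielded value
  // escapes the region through the parent op's result.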
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

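  // Each yield operand aliases the parent shape.assuming result with the same
  // index; this is the counterpart of `getAliasingOpOperand` on the parent op.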
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(isa<shape::AssumingOp>(op->getParentOp()) &&
           "expected that parent is an AssumingOp");
    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize in-place. Otherwise, an alloc + copy
    // may be generated inside the block, and allocations should not be
    // yielded out of a block when avoidable.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // Op is bufferized as part of AssumingOp.
    return failure();
  }
};

} // namespace
} // namespace shape
} // namespace mlir

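// Attach the external models when the shape dialect is loaded into a context.
// A sketch of typical usage when setting up a context (where registration
// happens varies by pipeline):
//
//   DialectRegistry registry;
//   mlir::shape::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);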
void mlir::shape::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
    shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
    shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
  });
}