//===- AllocTensorElimination.cpp - alloc_tensor op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::bufferization;

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `allocTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *allocTensorOp) {
  for (Operation *user : allocTensorOp->getUsers())
    if (!domInfo.dominates(insertionPoint, user))
      return false;
  return true;
}

/// Find a valid insertion point for a replacement of `allocTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
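///
/// A valid insertion point must be dominated by the definitions of all
/// `neededValues` (so that the replacement op's operands are in scope) and
/// must itself dominate every use of `allocTensorOp`.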
static Operation *
findValidInsertionPoint(Operation *allocTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `allocTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(allocTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the
    //   block (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //   defining op (the anchor op or one of its parents).
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select the first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, allocTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single AllocTensorOp.
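///
/// For illustration, mirroring the insert_slice example further below (a
/// sketch, not taken from a specific test case):
///
///   %0 = bufferization.alloc_tensor
///   %1 = linalg.fill(%cst, %0) {inplace = [true]}
///   %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// Starting from the source operand of the insert_slice and following only
/// inplace, aliasing OpOperands, the reverse use-def chain %1 -> %0 ends at a
/// single alloc_tensor, so that alloc_tensor is anchored on the operand.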
LogicalResult mlir::bufferization::eliminateAllocTensors(
    RewriterBase &rewriter, Operation *op, AnalysisState &state,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
  OpBuilder::InsertionGuard g(rewriter);

  WalkResult status = op->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Skip operands that do not bufferize inplace.
      if (!state.isInPlace(operand))
        continue;
      // All values that are needed to create the replacement op.
      SmallVector<Value> neededValues;
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand, neededValues))
        continue;
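      // Find the value(s) at the end of the reverse use-def chain that starts
      // at the matched operand, following only aliasing OpOperands that were
      // decided to bufferize inplace.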
      SetVector<Value> maybeAllocTensor =
          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
            // Continue traversal until this function returns true.
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            SmallVector<OpOperand *> opOperands =
                state.getAliasingOpOperand(opResult);
            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                  return state.isInPlace(*operand);
                }))
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(alloc_tensor)
            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
              return state.areEquivalentBufferizedValues(operand->get(),
                                                         opResult);
            });
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // AllocTensorOp.
      if (maybeAllocTensor.size() != 1 ||
          !maybeAllocTensor.front().getDefiningOp<AllocTensorOp>())
        return WalkResult::skip();
      Value allocTensor = maybeAllocTensor.front();

      // Find a suitable insertion point.
      Operation *insertionPoint =
          findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues);
      if (!insertionPoint)
        continue;

      // Create a replacement for the AllocTensorOp.
      rewriter.setInsertionPoint(insertionPoint);
      Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand);
      if (!replacement)
        continue;

      // Replace the AllocTensorOp.
      rewriter.replaceOp(allocTensor.getDefiningOp(), replacement);
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = bufferization.alloc_tensor
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// AllocTensorOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// AllocTensorOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an AllocTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * On the reverse use-def chain path from the InsertSliceOp to the
///   AllocTensorOp, all ops were decided to bufferize inplace and the buffer
///   relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the
///   AllocTensorOp.
LogicalResult
mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
  return eliminateAllocTensors(
      rewriter, op, state,
      /*anchorMatchFunc=*/
      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
        auto insertSliceOp =
            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
          return false;

        // Collect all values that are needed to construct the replacement op.
        neededValues.append(insertSliceOp.getOffsets().begin(),
                            insertSliceOp.getOffsets().end());
        neededValues.append(insertSliceOp.getSizes().begin(),
                            insertSliceOp.getSizes().end());
        neededValues.append(insertSliceOp.getStrides().begin(),
                            insertSliceOp.getStrides().end());
        neededValues.push_back(insertSliceOp.getDest());

        return true;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, insertOp.getSourceType(), insertOp.getDest(),
            insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
            insertOp.getMixedStrides());
        return extractOp.getResult();
      });
}

namespace {
struct AllocTensorElimination
    : public AllocTensorEliminationBase<AllocTensorElimination> {
  AllocTensorElimination() = default;

  void runOnOperation() override;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
  }
};
} // namespace

void AllocTensorElimination::runOnOperation() {
  Operation *op = getOperation();
  OneShotBufferizationOptions options;
  OneShotAnalysisState state(op, options);
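  // Run One-Shot Bufferize analysis to decide which OpOperands bufferize
  // inplace. The elimination step below relies on these inplace decisions.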
  if (failed(analyzeOp(op, state))) {
    signalPassFailure();
    return;
  }

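  // Rewrite alloc_tensor ops that are anchored on matching insert_slice ops.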
  IRRewriter rewriter(op->getContext());
  if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep(
          rewriter, op, state)))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::bufferization::createAllocTensorEliminationPass() {
  return std::make_unique<AllocTensorElimination>();
}