//===- AllocTensorElimination.cpp - alloc_tensor op elimination ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::bufferization;

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
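      // A block argument is in scope for every op nested under its owning
      // block, so it suffices to check that the insertion point (or one of
      // its ancestors) lies in that block.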
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      // The defining op must be strictly before the insertion point, i.e., it
      // must properly dominate it.
      if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `allocTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *allocTensorOp) {
  for (Operation *user : allocTensorOp->getUsers())
    if (!domInfo.dominates(insertionPoint, user))
      return false;
  return true;
}

/// Find a valid insertion point for a replacement of `allocTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
static Operation *
findValidInsertionPoint(Operation *allocTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `allocTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(allocTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the block
    //                                (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //                           defining op (the anchor op or one of its
    //                           parents).
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, allocTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single AllocTensorOp.
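///
/// For illustration (using the tensor.insert_slice anchor from below): in
///   %0 = bufferization.alloc_tensor()
///   %1 = linalg.fill(%cst, %0)
///   %2 = tensor.insert_slice %1 into %t[10][20][1]
/// the insert_slice source operand %1 is anchored on an AllocTensorOp: the
/// aliasing OpOperand of %1 is the fill's output %0, which is defined by
/// bufferization.alloc_tensor.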
LogicalResult mlir::bufferization::eliminateAllocTensors(
    RewriterBase &rewriter, Operation *op, AnalysisState &state,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
  OpBuilder::InsertionGuard g(rewriter);

  WalkResult status = op->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Skip operands that do not bufferize inplace.
      if (!state.isInPlace(operand))
        continue;
      // All values that are needed to create the replacement op.
      SmallVector<Value> neededValues;
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand, neededValues))
        continue;
      SetVector<Value> maybeAllocTensor =
          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
            // Continue traversal until this function returns true.
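            // Stop at block arguments: a block argument can never be the
            // result of an AllocTensorOp, so such a chain end is rejected
            // below.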
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            SmallVector<OpOperand *> opOperands =
                state.getAliasingOpOperand(opResult);
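            // Stop the traversal if any aliasing OpOperand bufferizes
            // out-of-place. Such a chain end cannot be an AllocTensorOp, so
            // the candidate is rejected below.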
            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                  return state.isInPlace(*operand);
                }))
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(alloc_tensor)
            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
              return state.areEquivalentBufferizedValues(operand->get(),
                                                         opResult);
            });
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // AllocTensorOp.
      if (maybeAllocTensor.size() != 1 ||
          !maybeAllocTensor.front().getDefiningOp<AllocTensorOp>())
        return WalkResult::skip();
      Value allocTensor = maybeAllocTensor.front();

      // Find a suitable insertion point.
      Operation *insertionPoint =
          findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues);
      if (!insertionPoint)
        continue;

      // Create a replacement for the AllocTensorOp.
      rewriter.setInsertionPoint(insertionPoint);
      Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand);
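      // `rewriteFunc` may fail to produce a replacement value. In that case,
      // leave the AllocTensorOp in place.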
      if (!replacement)
        continue;

      // Replace the AllocTensorOp.
      rewriter.replaceOp(allocTensor.getDefiningOp(), replacement);
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = bufferization.alloc_tensor()
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// AllocTensorOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// AllocTensorOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
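///
/// The rewritten IR is then (sketch):
///
/// %0 = tensor.extract_slice %t[10][20][1]
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]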
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an AllocTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * On the reverse use-def chain path from the InsertSliceOp to the
///   AllocTensorOp, all ops were decided to bufferize inplace and the buffer
///   relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the AllocTensorOp.
LogicalResult
mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
  return eliminateAllocTensors(
      rewriter, op, state,
      /*anchorMatchFunc=*/
      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
        auto insertSliceOp =
            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
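        // Only the `source` operand of the InsertSliceOp is a matching
        // anchor: the AllocTensorOp at the end of its reverse use-def chain
        // can be replaced with an extract_slice of the destination.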
        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
          return false;

        // Collect all values that are needed to construct the replacement op.
        neededValues.append(insertSliceOp.getOffsets().begin(),
                            insertSliceOp.getOffsets().end());
        neededValues.append(insertSliceOp.getSizes().begin(),
                            insertSliceOp.getSizes().end());
        neededValues.append(insertSliceOp.getStrides().begin(),
                            insertSliceOp.getStrides().end());
        neededValues.push_back(insertSliceOp.getDest());

        return true;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
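        // Build the replacement: an ExtractSliceOp that reads the
        // InsertSliceOp's destination at the same offsets/sizes/strides.
        // After bufferization, ops that previously used the allocation write
        // directly into the destination buffer.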
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, insertOp.getSourceType(), insertOp.getDest(),
            insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
            insertOp.getMixedStrides());
        return extractOp.getResult();
      });
}

namespace {
struct AllocTensorElimination
    : public AllocTensorEliminationBase<AllocTensorElimination> {
  AllocTensorElimination() = default;

  void runOnOperation() override;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
  }
};
} // namespace

void AllocTensorElimination::runOnOperation() {
  Operation *op = getOperation();
  OneShotBufferizationOptions options;
  OneShotAnalysisState state(op, options);
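  // Run the One-Shot Bufferize analysis. The elimination below relies on its
  // in-place/out-of-place bufferization decisions.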
  if (failed(analyzeOp(op, state))) {
    signalPassFailure();
    return;
  }

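  // Rewrite AllocTensorOps whose reverse use-def chain is anchored on a
  // matching tensor.insert_slice.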
  IRRewriter rewriter(op->getContext());
  if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep(
          rewriter, op, state)))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::bufferization::createAllocTensorEliminationPass() {
  return std::make_unique<AllocTensorElimination>();
}