//===- AllocTensorElimination.cpp - alloc_tensor op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8ffdbecccSMatthias Springer 
9ffdbecccSMatthias Springer #include "PassDetail.h"
10ffdbecccSMatthias Springer 
11ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h"
14ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15ffdbecccSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
16ffdbecccSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
17ffdbecccSMatthias Springer #include "mlir/IR/Dominance.h"
18ffdbecccSMatthias Springer #include "mlir/Pass/Pass.h"
19ffdbecccSMatthias Springer 
20ffdbecccSMatthias Springer using namespace mlir;
21ffdbecccSMatthias Springer using namespace mlir::bufferization;
22ffdbecccSMatthias Springer 
23ffdbecccSMatthias Springer /// Return true if all `neededValues` are in scope at the given
24ffdbecccSMatthias Springer /// `insertionPoint`.
25ffdbecccSMatthias Springer static bool
neededValuesDominateInsertionPoint(const DominanceInfo & domInfo,Operation * insertionPoint,const SmallVector<Value> & neededValues)26ffdbecccSMatthias Springer neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
27ffdbecccSMatthias Springer                                    Operation *insertionPoint,
28ffdbecccSMatthias Springer                                    const SmallVector<Value> &neededValues) {
29ffdbecccSMatthias Springer   for (Value val : neededValues) {
30ffdbecccSMatthias Springer     if (auto bbArg = val.dyn_cast<BlockArgument>()) {
31ffdbecccSMatthias Springer       Block *owner = bbArg.getOwner();
32ffdbecccSMatthias Springer       if (!owner->findAncestorOpInBlock(*insertionPoint))
33ffdbecccSMatthias Springer         return false;
34ffdbecccSMatthias Springer     } else {
35ffdbecccSMatthias Springer       auto opResult = val.cast<OpResult>();
36ffdbecccSMatthias Springer       if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
37ffdbecccSMatthias Springer         return false;
38ffdbecccSMatthias Springer     }
39ffdbecccSMatthias Springer   }
40ffdbecccSMatthias Springer   return true;
41ffdbecccSMatthias Springer }
42ffdbecccSMatthias Springer 
43ffdbecccSMatthias Springer /// Return true if the given `insertionPoint` dominates all uses of
44ffdbecccSMatthias Springer /// `allocTensorOp`.
insertionPointDominatesUses(const DominanceInfo & domInfo,Operation * insertionPoint,Operation * allocTensorOp)45ffdbecccSMatthias Springer static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
46ffdbecccSMatthias Springer                                         Operation *insertionPoint,
47ffdbecccSMatthias Springer                                         Operation *allocTensorOp) {
48ffdbecccSMatthias Springer   for (Operation *user : allocTensorOp->getUsers())
49ffdbecccSMatthias Springer     if (!domInfo.dominates(insertionPoint, user))
50ffdbecccSMatthias Springer       return false;
51ffdbecccSMatthias Springer   return true;
52ffdbecccSMatthias Springer }
53ffdbecccSMatthias Springer 
54ffdbecccSMatthias Springer /// Find a valid insertion point for a replacement of `allocTensorOp`, assuming
55ffdbecccSMatthias Springer /// that the replacement may use any value from `neededValues`.
56ffdbecccSMatthias Springer static Operation *
findValidInsertionPoint(Operation * allocTensorOp,const SmallVector<Value> & neededValues)57ffdbecccSMatthias Springer findValidInsertionPoint(Operation *allocTensorOp,
58ffdbecccSMatthias Springer                         const SmallVector<Value> &neededValues) {
59ffdbecccSMatthias Springer   DominanceInfo domInfo;
60ffdbecccSMatthias Springer 
61ffdbecccSMatthias Springer   // Gather all possible insertion points: the location of `allocTensorOp` and
62ffdbecccSMatthias Springer   // right after the definition of each value in `neededValues`.
63ffdbecccSMatthias Springer   SmallVector<Operation *> insertionPointCandidates;
64ffdbecccSMatthias Springer   insertionPointCandidates.push_back(allocTensorOp);
65ffdbecccSMatthias Springer   for (Value val : neededValues) {
66ffdbecccSMatthias Springer     // Note: The anchor op is using all of `neededValues`, so:
67ffdbecccSMatthias Springer     // * in case of a block argument: There must be at least one op in the block
68ffdbecccSMatthias Springer     //                                (the anchor op or one of its parents).
69ffdbecccSMatthias Springer     // * in case of an OpResult: There must be at least one op right after the
70ffdbecccSMatthias Springer     //                           defining op (the anchor op or one of its
71ffdbecccSMatthias Springer     //                           parents).
72ffdbecccSMatthias Springer     if (auto bbArg = val.dyn_cast<BlockArgument>()) {
73ffdbecccSMatthias Springer       insertionPointCandidates.push_back(
74ffdbecccSMatthias Springer           &bbArg.getOwner()->getOperations().front());
75ffdbecccSMatthias Springer     } else {
76ffdbecccSMatthias Springer       insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
77ffdbecccSMatthias Springer     }
78ffdbecccSMatthias Springer   }
79ffdbecccSMatthias Springer 
80ffdbecccSMatthias Springer   // Select first matching insertion point.
81ffdbecccSMatthias Springer   for (Operation *insertionPoint : insertionPointCandidates) {
82ffdbecccSMatthias Springer     // Check if all needed values are in scope.
83ffdbecccSMatthias Springer     if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
84ffdbecccSMatthias Springer                                             neededValues))
85ffdbecccSMatthias Springer       continue;
86ffdbecccSMatthias Springer     // Check if the insertion point is before all uses.
87ffdbecccSMatthias Springer     if (!insertionPointDominatesUses(domInfo, insertionPoint, allocTensorOp))
88ffdbecccSMatthias Springer       continue;
89ffdbecccSMatthias Springer     return insertionPoint;
90ffdbecccSMatthias Springer   }
91ffdbecccSMatthias Springer 
92ffdbecccSMatthias Springer   // No suitable insertion point was found.
93ffdbecccSMatthias Springer   return nullptr;
94ffdbecccSMatthias Springer }
95ffdbecccSMatthias Springer 
96ffdbecccSMatthias Springer /// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp is replaced
97ffdbecccSMatthias Springer /// with the result of `rewriteFunc` if it is anchored on a matching
98ffdbecccSMatthias Springer /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
99ffdbecccSMatthias Springer /// chain, starting from the OpOperand and always following the aliasing
100ffdbecccSMatthias Springer /// OpOperand, that eventually ends at a single AllocTensorOp.
eliminateAllocTensors(RewriterBase & rewriter,Operation * op,AnalysisState & state,AnchorMatchFn anchorMatchFunc,RewriteFn rewriteFunc)101ffdbecccSMatthias Springer LogicalResult mlir::bufferization::eliminateAllocTensors(
102ffdbecccSMatthias Springer     RewriterBase &rewriter, Operation *op, AnalysisState &state,
103ffdbecccSMatthias Springer     AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
104ffdbecccSMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
105ffdbecccSMatthias Springer 
106ffdbecccSMatthias Springer   WalkResult status = op->walk([&](Operation *op) {
107ffdbecccSMatthias Springer     for (OpOperand &operand : op->getOpOperands()) {
108ffdbecccSMatthias Springer       // Skip operands that do not bufferize inplace.
109ffdbecccSMatthias Springer       if (!state.isInPlace(operand))
110ffdbecccSMatthias Springer         continue;
111ffdbecccSMatthias Springer       // All values that are needed to create the replacement op.
112ffdbecccSMatthias Springer       SmallVector<Value> neededValues;
113ffdbecccSMatthias Springer       // Is this a matching OpOperand?
114ffdbecccSMatthias Springer       if (!anchorMatchFunc(operand, neededValues))
115ffdbecccSMatthias Springer         continue;
116ffdbecccSMatthias Springer       SetVector<Value> maybeAllocTensor =
117ffdbecccSMatthias Springer           state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
118ffdbecccSMatthias Springer             // Continue traversal until this function returns true.
119ffdbecccSMatthias Springer             OpResult opResult = val.dyn_cast<OpResult>();
120ffdbecccSMatthias Springer             if (!opResult)
121ffdbecccSMatthias Springer               return true;
122ffdbecccSMatthias Springer             SmallVector<OpOperand *> opOperands =
123ffdbecccSMatthias Springer                 state.getAliasingOpOperand(opResult);
124ffdbecccSMatthias Springer             if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
125ffdbecccSMatthias Springer                   return state.isInPlace(*operand);
126ffdbecccSMatthias Springer                 }))
127ffdbecccSMatthias Springer               return true;
128ffdbecccSMatthias Springer             // Only equivalent tensors are supported at the moment.
129ffdbecccSMatthias Springer             // TODO: Support cases such as extract_slice(alloc_tensor)
130ffdbecccSMatthias Springer             return !llvm::all_of(opOperands, [&](OpOperand *operand) {
131ffdbecccSMatthias Springer               return state.areEquivalentBufferizedValues(operand->get(),
132ffdbecccSMatthias Springer                                                          opResult);
133ffdbecccSMatthias Springer             });
134ffdbecccSMatthias Springer           });
135ffdbecccSMatthias Springer 
136ffdbecccSMatthias Springer       // Replace only if the reverse use-def chain ends at exactly one
137ffdbecccSMatthias Springer       // AllocTensorOp.
138ffdbecccSMatthias Springer       if (maybeAllocTensor.size() != 1 ||
139ffdbecccSMatthias Springer           !maybeAllocTensor.front().getDefiningOp<AllocTensorOp>())
140ffdbecccSMatthias Springer         return WalkResult::skip();
141ffdbecccSMatthias Springer       Value allocTensor = maybeAllocTensor.front();
142ffdbecccSMatthias Springer 
143ffdbecccSMatthias Springer       // Find a suitable insertion point.
144ffdbecccSMatthias Springer       Operation *insertionPoint =
145ffdbecccSMatthias Springer           findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues);
146ffdbecccSMatthias Springer       if (!insertionPoint)
147ffdbecccSMatthias Springer         continue;
148ffdbecccSMatthias Springer 
149ffdbecccSMatthias Springer       // Create a replacement for the AllocTensorOp.
150ffdbecccSMatthias Springer       rewriter.setInsertionPoint(insertionPoint);
151ffdbecccSMatthias Springer       Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand);
152ffdbecccSMatthias Springer       if (!replacement)
153ffdbecccSMatthias Springer         continue;
154ffdbecccSMatthias Springer 
155ffdbecccSMatthias Springer       // Replace the AllocTensorOp.
156ffdbecccSMatthias Springer       rewriter.replaceOp(allocTensor.getDefiningOp(), replacement);
157ffdbecccSMatthias Springer     }
158ffdbecccSMatthias Springer 
159ffdbecccSMatthias Springer     // Advance to the next operation.
160ffdbecccSMatthias Springer     return WalkResult::advance();
161ffdbecccSMatthias Springer   });
162ffdbecccSMatthias Springer 
163ffdbecccSMatthias Springer   return failure(status.wasInterrupted());
164ffdbecccSMatthias Springer }
165ffdbecccSMatthias Springer 
166ffdbecccSMatthias Springer /// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp can be
167ffdbecccSMatthias Springer /// eliminated if it is eventually inserted into another tensor (and some other
168ffdbecccSMatthias Springer /// conditions are met).
169ffdbecccSMatthias Springer ///
170ffdbecccSMatthias Springer /// E.g.:
171ffdbecccSMatthias Springer /// %0 = linalg.alloc_tensor
172ffdbecccSMatthias Springer /// %1 = linalg.fill(%cst, %0) {inplace = [true]}
173ffdbecccSMatthias Springer /// %2 = tensor.insert_slice %1 into %t[10][20][1]
174ffdbecccSMatthias Springer ///
175ffdbecccSMatthias Springer /// AllocTensorOp elimination will try to fill %t inplace instead of filling a
176ffdbecccSMatthias Springer /// new allocation %0 and inserting it into %t. This is done by replacing the
177ffdbecccSMatthias Springer /// AllocTensorOp with:
178ffdbecccSMatthias Springer ///
179ffdbecccSMatthias Springer /// %0 = tensor.extract_slice %t[10][20][1]
180ffdbecccSMatthias Springer ///
181ffdbecccSMatthias Springer /// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
182ffdbecccSMatthias Springer /// those bufferize inplace in the absence of other conflicts.
183ffdbecccSMatthias Springer ///
184ffdbecccSMatthias Springer /// Starting from an InsertSliceOp, an AllocTensorOp at the end of the insert
185ffdbecccSMatthias Springer /// source's reverse use-def chain is eliminated if:
186ffdbecccSMatthias Springer /// * On the reverse use-def chain path from the InsertSliceOp to the
187ffdbecccSMatthias Springer ///   AllocTensorOp, all ops were decided to bufferize inplace and the buffer
188ffdbecccSMatthias Springer ///   relation is "equivalent" (TODO: can be relaxed if needed).
189ffdbecccSMatthias Springer /// * The reverse use-def chain has exactly one end, which is the AllocTensorOp.
190ffdbecccSMatthias Springer LogicalResult
insertSliceAnchoredAllocTensorEliminationStep(RewriterBase & rewriter,Operation * op,AnalysisState & state)191ffdbecccSMatthias Springer mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
192ffdbecccSMatthias Springer     RewriterBase &rewriter, Operation *op, AnalysisState &state) {
193ffdbecccSMatthias Springer   return eliminateAllocTensors(
194ffdbecccSMatthias Springer       rewriter, op, state,
195ffdbecccSMatthias Springer       /*anchorMatchFunc=*/
196ffdbecccSMatthias Springer       [&](OpOperand &operand, SmallVector<Value> &neededValues) {
197ffdbecccSMatthias Springer         auto insertSliceOp =
198ffdbecccSMatthias Springer             dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
199ffdbecccSMatthias Springer         if (!insertSliceOp)
200ffdbecccSMatthias Springer           return false;
201ffdbecccSMatthias Springer         if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
202ffdbecccSMatthias Springer           return false;
203ffdbecccSMatthias Springer 
204ffdbecccSMatthias Springer         // Collect all values that are needed to construct the replacement op.
20504235d07SJacques Pienaar         neededValues.append(insertSliceOp.getOffsets().begin(),
20604235d07SJacques Pienaar                             insertSliceOp.getOffsets().end());
20704235d07SJacques Pienaar         neededValues.append(insertSliceOp.getSizes().begin(),
20804235d07SJacques Pienaar                             insertSliceOp.getSizes().end());
20904235d07SJacques Pienaar         neededValues.append(insertSliceOp.getStrides().begin(),
21004235d07SJacques Pienaar                             insertSliceOp.getStrides().end());
21104235d07SJacques Pienaar         neededValues.push_back(insertSliceOp.getDest());
212ffdbecccSMatthias Springer 
213ffdbecccSMatthias Springer         return true;
214ffdbecccSMatthias Springer       },
215ffdbecccSMatthias Springer       /*rewriteFunc=*/
216ffdbecccSMatthias Springer       [](OpBuilder &b, Location loc, OpOperand &operand) {
217ffdbecccSMatthias Springer         auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
218ffdbecccSMatthias Springer         auto extractOp = b.create<tensor::ExtractSliceOp>(
219*6c3c5f80SMatthias Springer             loc, insertOp.getSourceType(), insertOp.getDest(),
220*6c3c5f80SMatthias Springer             insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
221*6c3c5f80SMatthias Springer             insertOp.getMixedStrides());
22204235d07SJacques Pienaar         return extractOp.getResult();
223ffdbecccSMatthias Springer       });
224ffdbecccSMatthias Springer }
225ffdbecccSMatthias Springer 
226ffdbecccSMatthias Springer namespace {
227ffdbecccSMatthias Springer struct AllocTensorElimination
228ffdbecccSMatthias Springer     : public AllocTensorEliminationBase<AllocTensorElimination> {
229ffdbecccSMatthias Springer   AllocTensorElimination() = default;
230ffdbecccSMatthias Springer 
231ffdbecccSMatthias Springer   void runOnOperation() override;
232ffdbecccSMatthias Springer 
getDependentDialects__anon8ff24c420711::AllocTensorElimination233ffdbecccSMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
234ffdbecccSMatthias Springer     registry
235ffdbecccSMatthias Springer         .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
236ffdbecccSMatthias Springer   }
237ffdbecccSMatthias Springer };
238ffdbecccSMatthias Springer } // namespace
239ffdbecccSMatthias Springer 
runOnOperation()240ffdbecccSMatthias Springer void AllocTensorElimination::runOnOperation() {
241ffdbecccSMatthias Springer   Operation *op = getOperation();
242ffdbecccSMatthias Springer   OneShotBufferizationOptions options;
243ffdbecccSMatthias Springer   OneShotAnalysisState state(op, options);
244ffdbecccSMatthias Springer   if (failed(analyzeOp(op, state))) {
245ffdbecccSMatthias Springer     signalPassFailure();
246ffdbecccSMatthias Springer     return;
247ffdbecccSMatthias Springer   }
248ffdbecccSMatthias Springer 
249ffdbecccSMatthias Springer   IRRewriter rewriter(op->getContext());
250ffdbecccSMatthias Springer   if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep(
251ffdbecccSMatthias Springer           rewriter, op, state)))
252ffdbecccSMatthias Springer     signalPassFailure();
253ffdbecccSMatthias Springer }
254ffdbecccSMatthias Springer 
createAllocTensorEliminationPass()255ffdbecccSMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createAllocTensorEliminationPass() {
256ffdbecccSMatthias Springer   return std::make_unique<AllocTensorElimination>();
257ffdbecccSMatthias Springer }
258