//===- AllocTensorElimination.cpp - alloc_tensor op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8ffdbecccSMatthias Springer
#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::bufferization;
22ffdbecccSMatthias Springer
23ffdbecccSMatthias Springer /// Return true if all `neededValues` are in scope at the given
24ffdbecccSMatthias Springer /// `insertionPoint`.
25ffdbecccSMatthias Springer static bool
neededValuesDominateInsertionPoint(const DominanceInfo & domInfo,Operation * insertionPoint,const SmallVector<Value> & neededValues)26ffdbecccSMatthias Springer neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
27ffdbecccSMatthias Springer Operation *insertionPoint,
28ffdbecccSMatthias Springer const SmallVector<Value> &neededValues) {
29ffdbecccSMatthias Springer for (Value val : neededValues) {
30ffdbecccSMatthias Springer if (auto bbArg = val.dyn_cast<BlockArgument>()) {
31ffdbecccSMatthias Springer Block *owner = bbArg.getOwner();
32ffdbecccSMatthias Springer if (!owner->findAncestorOpInBlock(*insertionPoint))
33ffdbecccSMatthias Springer return false;
34ffdbecccSMatthias Springer } else {
35ffdbecccSMatthias Springer auto opResult = val.cast<OpResult>();
36ffdbecccSMatthias Springer if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
37ffdbecccSMatthias Springer return false;
38ffdbecccSMatthias Springer }
39ffdbecccSMatthias Springer }
40ffdbecccSMatthias Springer return true;
41ffdbecccSMatthias Springer }
42ffdbecccSMatthias Springer
43ffdbecccSMatthias Springer /// Return true if the given `insertionPoint` dominates all uses of
44ffdbecccSMatthias Springer /// `allocTensorOp`.
insertionPointDominatesUses(const DominanceInfo & domInfo,Operation * insertionPoint,Operation * allocTensorOp)45ffdbecccSMatthias Springer static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
46ffdbecccSMatthias Springer Operation *insertionPoint,
47ffdbecccSMatthias Springer Operation *allocTensorOp) {
48ffdbecccSMatthias Springer for (Operation *user : allocTensorOp->getUsers())
49ffdbecccSMatthias Springer if (!domInfo.dominates(insertionPoint, user))
50ffdbecccSMatthias Springer return false;
51ffdbecccSMatthias Springer return true;
52ffdbecccSMatthias Springer }
53ffdbecccSMatthias Springer
54ffdbecccSMatthias Springer /// Find a valid insertion point for a replacement of `allocTensorOp`, assuming
55ffdbecccSMatthias Springer /// that the replacement may use any value from `neededValues`.
56ffdbecccSMatthias Springer static Operation *
findValidInsertionPoint(Operation * allocTensorOp,const SmallVector<Value> & neededValues)57ffdbecccSMatthias Springer findValidInsertionPoint(Operation *allocTensorOp,
58ffdbecccSMatthias Springer const SmallVector<Value> &neededValues) {
59ffdbecccSMatthias Springer DominanceInfo domInfo;
60ffdbecccSMatthias Springer
61ffdbecccSMatthias Springer // Gather all possible insertion points: the location of `allocTensorOp` and
62ffdbecccSMatthias Springer // right after the definition of each value in `neededValues`.
63ffdbecccSMatthias Springer SmallVector<Operation *> insertionPointCandidates;
64ffdbecccSMatthias Springer insertionPointCandidates.push_back(allocTensorOp);
65ffdbecccSMatthias Springer for (Value val : neededValues) {
66ffdbecccSMatthias Springer // Note: The anchor op is using all of `neededValues`, so:
67ffdbecccSMatthias Springer // * in case of a block argument: There must be at least one op in the block
68ffdbecccSMatthias Springer // (the anchor op or one of its parents).
69ffdbecccSMatthias Springer // * in case of an OpResult: There must be at least one op right after the
70ffdbecccSMatthias Springer // defining op (the anchor op or one of its
71ffdbecccSMatthias Springer // parents).
72ffdbecccSMatthias Springer if (auto bbArg = val.dyn_cast<BlockArgument>()) {
73ffdbecccSMatthias Springer insertionPointCandidates.push_back(
74ffdbecccSMatthias Springer &bbArg.getOwner()->getOperations().front());
75ffdbecccSMatthias Springer } else {
76ffdbecccSMatthias Springer insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
77ffdbecccSMatthias Springer }
78ffdbecccSMatthias Springer }
79ffdbecccSMatthias Springer
80ffdbecccSMatthias Springer // Select first matching insertion point.
81ffdbecccSMatthias Springer for (Operation *insertionPoint : insertionPointCandidates) {
82ffdbecccSMatthias Springer // Check if all needed values are in scope.
83ffdbecccSMatthias Springer if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
84ffdbecccSMatthias Springer neededValues))
85ffdbecccSMatthias Springer continue;
86ffdbecccSMatthias Springer // Check if the insertion point is before all uses.
87ffdbecccSMatthias Springer if (!insertionPointDominatesUses(domInfo, insertionPoint, allocTensorOp))
88ffdbecccSMatthias Springer continue;
89ffdbecccSMatthias Springer return insertionPoint;
90ffdbecccSMatthias Springer }
91ffdbecccSMatthias Springer
92ffdbecccSMatthias Springer // No suitable insertion point was found.
93ffdbecccSMatthias Springer return nullptr;
94ffdbecccSMatthias Springer }
95ffdbecccSMatthias Springer
96ffdbecccSMatthias Springer /// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp is replaced
97ffdbecccSMatthias Springer /// with the result of `rewriteFunc` if it is anchored on a matching
98ffdbecccSMatthias Springer /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
99ffdbecccSMatthias Springer /// chain, starting from the OpOperand and always following the aliasing
100ffdbecccSMatthias Springer /// OpOperand, that eventually ends at a single AllocTensorOp.
eliminateAllocTensors(RewriterBase & rewriter,Operation * op,AnalysisState & state,AnchorMatchFn anchorMatchFunc,RewriteFn rewriteFunc)101ffdbecccSMatthias Springer LogicalResult mlir::bufferization::eliminateAllocTensors(
102ffdbecccSMatthias Springer RewriterBase &rewriter, Operation *op, AnalysisState &state,
103ffdbecccSMatthias Springer AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
104ffdbecccSMatthias Springer OpBuilder::InsertionGuard g(rewriter);
105ffdbecccSMatthias Springer
106ffdbecccSMatthias Springer WalkResult status = op->walk([&](Operation *op) {
107ffdbecccSMatthias Springer for (OpOperand &operand : op->getOpOperands()) {
108ffdbecccSMatthias Springer // Skip operands that do not bufferize inplace.
109ffdbecccSMatthias Springer if (!state.isInPlace(operand))
110ffdbecccSMatthias Springer continue;
111ffdbecccSMatthias Springer // All values that are needed to create the replacement op.
112ffdbecccSMatthias Springer SmallVector<Value> neededValues;
113ffdbecccSMatthias Springer // Is this a matching OpOperand?
114ffdbecccSMatthias Springer if (!anchorMatchFunc(operand, neededValues))
115ffdbecccSMatthias Springer continue;
116ffdbecccSMatthias Springer SetVector<Value> maybeAllocTensor =
117ffdbecccSMatthias Springer state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
118ffdbecccSMatthias Springer // Continue traversal until this function returns true.
119ffdbecccSMatthias Springer OpResult opResult = val.dyn_cast<OpResult>();
120ffdbecccSMatthias Springer if (!opResult)
121ffdbecccSMatthias Springer return true;
122ffdbecccSMatthias Springer SmallVector<OpOperand *> opOperands =
123ffdbecccSMatthias Springer state.getAliasingOpOperand(opResult);
124ffdbecccSMatthias Springer if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
125ffdbecccSMatthias Springer return state.isInPlace(*operand);
126ffdbecccSMatthias Springer }))
127ffdbecccSMatthias Springer return true;
128ffdbecccSMatthias Springer // Only equivalent tensors are supported at the moment.
129ffdbecccSMatthias Springer // TODO: Support cases such as extract_slice(alloc_tensor)
130ffdbecccSMatthias Springer return !llvm::all_of(opOperands, [&](OpOperand *operand) {
131ffdbecccSMatthias Springer return state.areEquivalentBufferizedValues(operand->get(),
132ffdbecccSMatthias Springer opResult);
133ffdbecccSMatthias Springer });
134ffdbecccSMatthias Springer });
135ffdbecccSMatthias Springer
136ffdbecccSMatthias Springer // Replace only if the reverse use-def chain ends at exactly one
137ffdbecccSMatthias Springer // AllocTensorOp.
138ffdbecccSMatthias Springer if (maybeAllocTensor.size() != 1 ||
139ffdbecccSMatthias Springer !maybeAllocTensor.front().getDefiningOp<AllocTensorOp>())
140ffdbecccSMatthias Springer return WalkResult::skip();
141ffdbecccSMatthias Springer Value allocTensor = maybeAllocTensor.front();
142ffdbecccSMatthias Springer
143ffdbecccSMatthias Springer // Find a suitable insertion point.
144ffdbecccSMatthias Springer Operation *insertionPoint =
145ffdbecccSMatthias Springer findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues);
146ffdbecccSMatthias Springer if (!insertionPoint)
147ffdbecccSMatthias Springer continue;
148ffdbecccSMatthias Springer
149ffdbecccSMatthias Springer // Create a replacement for the AllocTensorOp.
150ffdbecccSMatthias Springer rewriter.setInsertionPoint(insertionPoint);
151ffdbecccSMatthias Springer Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand);
152ffdbecccSMatthias Springer if (!replacement)
153ffdbecccSMatthias Springer continue;
154ffdbecccSMatthias Springer
155ffdbecccSMatthias Springer // Replace the AllocTensorOp.
156ffdbecccSMatthias Springer rewriter.replaceOp(allocTensor.getDefiningOp(), replacement);
157ffdbecccSMatthias Springer }
158ffdbecccSMatthias Springer
159ffdbecccSMatthias Springer // Advance to the next operation.
160ffdbecccSMatthias Springer return WalkResult::advance();
161ffdbecccSMatthias Springer });
162ffdbecccSMatthias Springer
163ffdbecccSMatthias Springer return failure(status.wasInterrupted());
164ffdbecccSMatthias Springer }
165ffdbecccSMatthias Springer
166ffdbecccSMatthias Springer /// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp can be
167ffdbecccSMatthias Springer /// eliminated if it is eventually inserted into another tensor (and some other
168ffdbecccSMatthias Springer /// conditions are met).
169ffdbecccSMatthias Springer ///
170ffdbecccSMatthias Springer /// E.g.:
171ffdbecccSMatthias Springer /// %0 = linalg.alloc_tensor
172ffdbecccSMatthias Springer /// %1 = linalg.fill(%cst, %0) {inplace = [true]}
173ffdbecccSMatthias Springer /// %2 = tensor.insert_slice %1 into %t[10][20][1]
174ffdbecccSMatthias Springer ///
175ffdbecccSMatthias Springer /// AllocTensorOp elimination will try to fill %t inplace instead of filling a
176ffdbecccSMatthias Springer /// new allocation %0 and inserting it into %t. This is done by replacing the
177ffdbecccSMatthias Springer /// AllocTensorOp with:
178ffdbecccSMatthias Springer ///
179ffdbecccSMatthias Springer /// %0 = tensor.extract_slice %t[10][20][1]
180ffdbecccSMatthias Springer ///
181ffdbecccSMatthias Springer /// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
182ffdbecccSMatthias Springer /// those bufferize inplace in the absence of other conflicts.
183ffdbecccSMatthias Springer ///
184ffdbecccSMatthias Springer /// Starting from an InsertSliceOp, an AllocTensorOp at the end of the insert
185ffdbecccSMatthias Springer /// source's reverse use-def chain is eliminated if:
186ffdbecccSMatthias Springer /// * On the reverse use-def chain path from the InsertSliceOp to the
187ffdbecccSMatthias Springer /// AllocTensorOp, all ops were decided to bufferize inplace and the buffer
188ffdbecccSMatthias Springer /// relation is "equivalent" (TODO: can be relaxed if needed).
189ffdbecccSMatthias Springer /// * The reverse use-def chain has exactly one end, which is the AllocTensorOp.
190ffdbecccSMatthias Springer LogicalResult
insertSliceAnchoredAllocTensorEliminationStep(RewriterBase & rewriter,Operation * op,AnalysisState & state)191ffdbecccSMatthias Springer mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
192ffdbecccSMatthias Springer RewriterBase &rewriter, Operation *op, AnalysisState &state) {
193ffdbecccSMatthias Springer return eliminateAllocTensors(
194ffdbecccSMatthias Springer rewriter, op, state,
195ffdbecccSMatthias Springer /*anchorMatchFunc=*/
196ffdbecccSMatthias Springer [&](OpOperand &operand, SmallVector<Value> &neededValues) {
197ffdbecccSMatthias Springer auto insertSliceOp =
198ffdbecccSMatthias Springer dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
199ffdbecccSMatthias Springer if (!insertSliceOp)
200ffdbecccSMatthias Springer return false;
201ffdbecccSMatthias Springer if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
202ffdbecccSMatthias Springer return false;
203ffdbecccSMatthias Springer
204ffdbecccSMatthias Springer // Collect all values that are needed to construct the replacement op.
20504235d07SJacques Pienaar neededValues.append(insertSliceOp.getOffsets().begin(),
20604235d07SJacques Pienaar insertSliceOp.getOffsets().end());
20704235d07SJacques Pienaar neededValues.append(insertSliceOp.getSizes().begin(),
20804235d07SJacques Pienaar insertSliceOp.getSizes().end());
20904235d07SJacques Pienaar neededValues.append(insertSliceOp.getStrides().begin(),
21004235d07SJacques Pienaar insertSliceOp.getStrides().end());
21104235d07SJacques Pienaar neededValues.push_back(insertSliceOp.getDest());
212ffdbecccSMatthias Springer
213ffdbecccSMatthias Springer return true;
214ffdbecccSMatthias Springer },
215ffdbecccSMatthias Springer /*rewriteFunc=*/
216ffdbecccSMatthias Springer [](OpBuilder &b, Location loc, OpOperand &operand) {
217ffdbecccSMatthias Springer auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
218ffdbecccSMatthias Springer auto extractOp = b.create<tensor::ExtractSliceOp>(
219*6c3c5f80SMatthias Springer loc, insertOp.getSourceType(), insertOp.getDest(),
220*6c3c5f80SMatthias Springer insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
221*6c3c5f80SMatthias Springer insertOp.getMixedStrides());
22204235d07SJacques Pienaar return extractOp.getResult();
223ffdbecccSMatthias Springer });
224ffdbecccSMatthias Springer }
225ffdbecccSMatthias Springer
226ffdbecccSMatthias Springer namespace {
227ffdbecccSMatthias Springer struct AllocTensorElimination
228ffdbecccSMatthias Springer : public AllocTensorEliminationBase<AllocTensorElimination> {
229ffdbecccSMatthias Springer AllocTensorElimination() = default;
230ffdbecccSMatthias Springer
231ffdbecccSMatthias Springer void runOnOperation() override;
232ffdbecccSMatthias Springer
getDependentDialects__anon8ff24c420711::AllocTensorElimination233ffdbecccSMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override {
234ffdbecccSMatthias Springer registry
235ffdbecccSMatthias Springer .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
236ffdbecccSMatthias Springer }
237ffdbecccSMatthias Springer };
238ffdbecccSMatthias Springer } // namespace
239ffdbecccSMatthias Springer
runOnOperation()240ffdbecccSMatthias Springer void AllocTensorElimination::runOnOperation() {
241ffdbecccSMatthias Springer Operation *op = getOperation();
242ffdbecccSMatthias Springer OneShotBufferizationOptions options;
243ffdbecccSMatthias Springer OneShotAnalysisState state(op, options);
244ffdbecccSMatthias Springer if (failed(analyzeOp(op, state))) {
245ffdbecccSMatthias Springer signalPassFailure();
246ffdbecccSMatthias Springer return;
247ffdbecccSMatthias Springer }
248ffdbecccSMatthias Springer
249ffdbecccSMatthias Springer IRRewriter rewriter(op->getContext());
250ffdbecccSMatthias Springer if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep(
251ffdbecccSMatthias Springer rewriter, op, state)))
252ffdbecccSMatthias Springer signalPassFailure();
253ffdbecccSMatthias Springer }
254ffdbecccSMatthias Springer
createAllocTensorEliminationPass()255ffdbecccSMatthias Springer std::unique_ptr<Pass> mlir::bufferization::createAllocTensorEliminationPass() {
256ffdbecccSMatthias Springer return std::make_unique<AllocTensorElimination>();
257ffdbecccSMatthias Springer }
258