1258dae5dSNicolas Vasilache //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
25c16564bSNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65c16564bSNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
85c16564bSNicolas Vasilache //
969d9e990SChris Lattner // This file implements Analysis functions specific to slicing in Function.
105c16564bSNicolas Vasilache //
115c16564bSNicolas Vasilache //===----------------------------------------------------------------------===//
125c16564bSNicolas Vasilache
135c16564bSNicolas Vasilache #include "mlir/Analysis/SliceAnalysis.h"
1465fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
159ffdc930SRiver Riddle #include "mlir/IR/Operation.h"
160c3923e1SMehdi Amini #include "mlir/Support/LLVM.h"
175c16564bSNicolas Vasilache #include "llvm/ADT/SetVector.h"
18*755dc07dSRiver Riddle #include "llvm/ADT/SmallPtrSet.h"
195c16564bSNicolas Vasilache
205c16564bSNicolas Vasilache ///
2169d9e990SChris Lattner /// Implements Analysis functions specific to slicing in Function.
225c16564bSNicolas Vasilache ///
235c16564bSNicolas Vasilache
245c16564bSNicolas Vasilache using namespace mlir;
255c16564bSNicolas Vasilache
getForwardSliceImpl(Operation * op,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)269c085406SRiver Riddle static void getForwardSliceImpl(Operation *op,
279c085406SRiver Riddle SetVector<Operation *> *forwardSlice,
28c3b0c6a0SNicolas Vasilache TransitiveFilter filter) {
29d01ea0edSNicolas Vasilache if (!op)
305c16564bSNicolas Vasilache return;
315c16564bSNicolas Vasilache
325c16564bSNicolas Vasilache // Evaluate whether we should keep this use.
335c16564bSNicolas Vasilache // This is useful in particular to implement scoping; i.e. return the
34258dae5dSNicolas Vasilache // transitive forwardSlice in the current scope.
35d01ea0edSNicolas Vasilache if (filter && !filter(op))
365c16564bSNicolas Vasilache return;
375c16564bSNicolas Vasilache
38d01ea0edSNicolas Vasilache for (Region ®ion : op->getRegions())
39d01ea0edSNicolas Vasilache for (Block &block : region)
40d01ea0edSNicolas Vasilache for (Operation &blockOp : block)
41d01ea0edSNicolas Vasilache if (forwardSlice->count(&blockOp) == 0)
42d01ea0edSNicolas Vasilache getForwardSliceImpl(&blockOp, forwardSlice, filter);
432f23270aSThomas Raoux for (Value result : op->getResults()) {
442f23270aSThomas Raoux for (Operation *userOp : result.getUsers())
452f23270aSThomas Raoux if (forwardSlice->count(userOp) == 0)
462f23270aSThomas Raoux getForwardSliceImpl(userOp, forwardSlice, filter);
475c16564bSNicolas Vasilache }
485c16564bSNicolas Vasilache
499c085406SRiver Riddle forwardSlice->insert(op);
505c16564bSNicolas Vasilache }
51c3b0c6a0SNicolas Vasilache
getForwardSlice(Operation * op,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)529c085406SRiver Riddle void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
53c3b0c6a0SNicolas Vasilache TransitiveFilter filter) {
549c085406SRiver Riddle getForwardSliceImpl(op, forwardSlice, filter);
559c085406SRiver Riddle // Don't insert the top level operation, we just queried on it and don't
56c3b0c6a0SNicolas Vasilache // want it in the results.
579c085406SRiver Riddle forwardSlice->remove(op);
58c3b0c6a0SNicolas Vasilache
59c3b0c6a0SNicolas Vasilache // Reverse to get back the actual topological order.
60c3b0c6a0SNicolas Vasilache // std::reverse does not work out of the box on SetVector and I want an
61c3b0c6a0SNicolas Vasilache // in-place swap based thing (the real std::reverse, not the LLVM adapter).
629c085406SRiver Riddle std::vector<Operation *> v(forwardSlice->takeVector());
63c3b0c6a0SNicolas Vasilache forwardSlice->insert(v.rbegin(), v.rend());
645c16564bSNicolas Vasilache }
655c16564bSNicolas Vasilache
getForwardSlice(Value root,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)66d01ea0edSNicolas Vasilache void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
67d01ea0edSNicolas Vasilache TransitiveFilter filter) {
68d01ea0edSNicolas Vasilache for (Operation *user : root.getUsers())
69d01ea0edSNicolas Vasilache getForwardSliceImpl(user, forwardSlice, filter);
70d01ea0edSNicolas Vasilache
71d01ea0edSNicolas Vasilache // Reverse to get back the actual topological order.
72d01ea0edSNicolas Vasilache // std::reverse does not work out of the box on SetVector and I want an
73d01ea0edSNicolas Vasilache // in-place swap based thing (the real std::reverse, not the LLVM adapter).
74d01ea0edSNicolas Vasilache std::vector<Operation *> v(forwardSlice->takeVector());
75d01ea0edSNicolas Vasilache forwardSlice->insert(v.rbegin(), v.rend());
76d01ea0edSNicolas Vasilache }
77d01ea0edSNicolas Vasilache
getBackwardSliceImpl(Operation * op,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)789c085406SRiver Riddle static void getBackwardSliceImpl(Operation *op,
799c085406SRiver Riddle SetVector<Operation *> *backwardSlice,
80c3b0c6a0SNicolas Vasilache TransitiveFilter filter) {
81d01ea0edSNicolas Vasilache if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
825c16564bSNicolas Vasilache return;
835c16564bSNicolas Vasilache
845c16564bSNicolas Vasilache // Evaluate whether we should keep this def.
855c16564bSNicolas Vasilache // This is useful in particular to implement scoping; i.e. return the
86d01ea0edSNicolas Vasilache // transitive backwardSlice in the current scope.
87d01ea0edSNicolas Vasilache if (filter && !filter(op))
885c16564bSNicolas Vasilache return;
895c16564bSNicolas Vasilache
90e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(op->getOperands())) {
9135807bc4SRiver Riddle auto operand = en.value();
92d01ea0edSNicolas Vasilache if (auto *definingOp = operand.getDefiningOp()) {
93d01ea0edSNicolas Vasilache if (backwardSlice->count(definingOp) == 0)
94d01ea0edSNicolas Vasilache getBackwardSliceImpl(definingOp, backwardSlice, filter);
95d01ea0edSNicolas Vasilache } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
96d01ea0edSNicolas Vasilache Block *block = blockArg.getOwner();
97d01ea0edSNicolas Vasilache Operation *parentOp = block->getParentOp();
98d01ea0edSNicolas Vasilache // TODO: determine whether we want to recurse backward into the other
99d01ea0edSNicolas Vasilache // blocks of parentOp, which are not technically backward unless they flow
100d01ea0edSNicolas Vasilache // into us. For now, just bail.
101d01ea0edSNicolas Vasilache assert(parentOp->getNumRegions() == 1 &&
102d01ea0edSNicolas Vasilache parentOp->getRegion(0).getBlocks().size() == 1);
103d01ea0edSNicolas Vasilache if (backwardSlice->count(parentOp) == 0)
104d01ea0edSNicolas Vasilache getBackwardSliceImpl(parentOp, backwardSlice, filter);
105d01ea0edSNicolas Vasilache } else {
106d01ea0edSNicolas Vasilache llvm_unreachable("No definingOp and not a block argument.");
1075c16564bSNicolas Vasilache }
1085c16564bSNicolas Vasilache }
1095c16564bSNicolas Vasilache
1109c085406SRiver Riddle backwardSlice->insert(op);
1115c16564bSNicolas Vasilache }
112c3b0c6a0SNicolas Vasilache
getBackwardSlice(Operation * op,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)1139c085406SRiver Riddle void mlir::getBackwardSlice(Operation *op,
1149c085406SRiver Riddle SetVector<Operation *> *backwardSlice,
115c3b0c6a0SNicolas Vasilache TransitiveFilter filter) {
1169c085406SRiver Riddle getBackwardSliceImpl(op, backwardSlice, filter);
117c3b0c6a0SNicolas Vasilache
1189c085406SRiver Riddle // Don't insert the top level operation, we just queried on it and don't
119c3b0c6a0SNicolas Vasilache // want it in the results.
1209c085406SRiver Riddle backwardSlice->remove(op);
1215c16564bSNicolas Vasilache }
1225c16564bSNicolas Vasilache
getBackwardSlice(Value root,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)123d01ea0edSNicolas Vasilache void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
124d01ea0edSNicolas Vasilache TransitiveFilter filter) {
125d01ea0edSNicolas Vasilache if (Operation *definingOp = root.getDefiningOp()) {
126d01ea0edSNicolas Vasilache getBackwardSlice(definingOp, backwardSlice, filter);
127d01ea0edSNicolas Vasilache return;
128d01ea0edSNicolas Vasilache }
129d01ea0edSNicolas Vasilache Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
130d01ea0edSNicolas Vasilache getBackwardSlice(bbAargOwner, backwardSlice, filter);
131d01ea0edSNicolas Vasilache }
132d01ea0edSNicolas Vasilache
getSlice(Operation * op,TransitiveFilter backwardFilter,TransitiveFilter forwardFilter)1339c085406SRiver Riddle SetVector<Operation *> mlir::getSlice(Operation *op,
1345c16564bSNicolas Vasilache TransitiveFilter backwardFilter,
1355c16564bSNicolas Vasilache TransitiveFilter forwardFilter) {
1369c085406SRiver Riddle SetVector<Operation *> slice;
1379c085406SRiver Riddle slice.insert(op);
1385c16564bSNicolas Vasilache
139258dae5dSNicolas Vasilache unsigned currentIndex = 0;
1409c085406SRiver Riddle SetVector<Operation *> backwardSlice;
1419c085406SRiver Riddle SetVector<Operation *> forwardSlice;
142258dae5dSNicolas Vasilache while (currentIndex != slice.size()) {
1436953cf65SNicolas Vasilache auto *currentOp = (slice)[currentIndex];
1446953cf65SNicolas Vasilache // Compute and insert the backwardSlice starting from currentOp.
145258dae5dSNicolas Vasilache backwardSlice.clear();
1466953cf65SNicolas Vasilache getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
147258dae5dSNicolas Vasilache slice.insert(backwardSlice.begin(), backwardSlice.end());
1485c16564bSNicolas Vasilache
1496953cf65SNicolas Vasilache // Compute and insert the forwardSlice starting from currentOp.
150258dae5dSNicolas Vasilache forwardSlice.clear();
1516953cf65SNicolas Vasilache getForwardSlice(currentOp, &forwardSlice, forwardFilter);
152258dae5dSNicolas Vasilache slice.insert(forwardSlice.begin(), forwardSlice.end());
1535c16564bSNicolas Vasilache ++currentIndex;
1545c16564bSNicolas Vasilache }
155258dae5dSNicolas Vasilache return topologicalSort(slice);
1565c16564bSNicolas Vasilache }
1575c16564bSNicolas Vasilache
1585c16564bSNicolas Vasilache namespace {
1595c16564bSNicolas Vasilache /// DFS post-order implementation that maintains a global count to work across
1605c16564bSNicolas Vasilache /// multiple invocations, to help implement topological sort on multi-root DAGs.
1619c085406SRiver Riddle /// We traverse all operations but only record the ones that appear in
162456ad6a8SChris Lattner /// `toSort` for the final result.
1635c16564bSNicolas Vasilache struct DFSState {
DFSState__anonfb076aa00111::DFSState1649c085406SRiver Riddle DFSState(const SetVector<Operation *> &set)
1655c16564bSNicolas Vasilache : toSort(set), topologicalCounts(), seen() {}
1669c085406SRiver Riddle const SetVector<Operation *> &toSort;
1679c085406SRiver Riddle SmallVector<Operation *, 16> topologicalCounts;
1689c085406SRiver Riddle DenseSet<Operation *> seen;
1695c16564bSNicolas Vasilache };
1705c16564bSNicolas Vasilache } // namespace
1715c16564bSNicolas Vasilache
dfsPostorder(Operation * root,DFSState * state)17202b6fb21SMehdi Amini static void dfsPostorder(Operation *root, DFSState *state) {
1735aa6038aSthomasraoux SmallVector<Operation *> queue(1, root);
1745aa6038aSthomasraoux std::vector<Operation *> ops;
1755aa6038aSthomasraoux while (!queue.empty()) {
1765aa6038aSthomasraoux Operation *current = queue.pop_back_val();
1775aa6038aSthomasraoux ops.push_back(current);
1782f23270aSThomas Raoux for (Value result : current->getResults()) {
1792f23270aSThomas Raoux for (Operation *op : result.getUsers())
1805aa6038aSthomasraoux queue.push_back(op);
1815c16564bSNicolas Vasilache }
1825aa6038aSthomasraoux for (Region ®ion : current->getRegions()) {
1835aa6038aSthomasraoux for (Operation &op : region.getOps())
1845aa6038aSthomasraoux queue.push_back(&op);
1855c16564bSNicolas Vasilache }
1865c16564bSNicolas Vasilache }
1875aa6038aSthomasraoux
1885aa6038aSthomasraoux for (Operation *op : llvm::reverse(ops)) {
1895aa6038aSthomasraoux if (state->seen.insert(op).second && state->toSort.count(op) > 0)
1905aa6038aSthomasraoux state->topologicalCounts.push_back(op);
1915aa6038aSthomasraoux }
1925c16564bSNicolas Vasilache }
1935c16564bSNicolas Vasilache
1949c085406SRiver Riddle SetVector<Operation *>
topologicalSort(const SetVector<Operation * > & toSort)1959c085406SRiver Riddle mlir::topologicalSort(const SetVector<Operation *> &toSort) {
1965c16564bSNicolas Vasilache if (toSort.empty()) {
1975c16564bSNicolas Vasilache return toSort;
1985c16564bSNicolas Vasilache }
1995c16564bSNicolas Vasilache
2005c16564bSNicolas Vasilache // Run from each root with global count and `seen` set.
2015c16564bSNicolas Vasilache DFSState state(toSort);
2025c16564bSNicolas Vasilache for (auto *s : toSort) {
2035c16564bSNicolas Vasilache assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
20402b6fb21SMehdi Amini dfsPostorder(s, &state);
2055c16564bSNicolas Vasilache }
2065c16564bSNicolas Vasilache
2075c16564bSNicolas Vasilache // Reorder and return.
2089c085406SRiver Riddle SetVector<Operation *> res;
2095c16564bSNicolas Vasilache for (auto it = state.topologicalCounts.rbegin(),
2105c16564bSNicolas Vasilache eit = state.topologicalCounts.rend();
2115c16564bSNicolas Vasilache it != eit; ++it) {
2125c16564bSNicolas Vasilache res.insert(*it);
2135c16564bSNicolas Vasilache }
2145c16564bSNicolas Vasilache return res;
2155c16564bSNicolas Vasilache }
216*755dc07dSRiver Riddle
217*755dc07dSRiver Riddle /// Returns true if `value` (transitively) depends on iteration-carried values
218*755dc07dSRiver Riddle /// of the given `ancestorOp`.
dependsOnCarriedVals(Value value,ArrayRef<BlockArgument> iterCarriedArgs,Operation * ancestorOp)219*755dc07dSRiver Riddle static bool dependsOnCarriedVals(Value value,
220*755dc07dSRiver Riddle ArrayRef<BlockArgument> iterCarriedArgs,
221*755dc07dSRiver Riddle Operation *ancestorOp) {
222*755dc07dSRiver Riddle // Compute the backward slice of the value.
223*755dc07dSRiver Riddle SetVector<Operation *> slice;
224*755dc07dSRiver Riddle getBackwardSlice(value, &slice,
225*755dc07dSRiver Riddle [&](Operation *op) { return !ancestorOp->isAncestor(op); });
226*755dc07dSRiver Riddle
227*755dc07dSRiver Riddle // Check that none of the operands of the operations in the backward slice are
228*755dc07dSRiver Riddle // loop iteration arguments, and neither is the value itself.
229*755dc07dSRiver Riddle SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
230*755dc07dSRiver Riddle iterCarriedArgs.end());
231*755dc07dSRiver Riddle if (iterCarriedValSet.contains(value))
232*755dc07dSRiver Riddle return true;
233*755dc07dSRiver Riddle
234*755dc07dSRiver Riddle for (Operation *op : slice)
235*755dc07dSRiver Riddle for (Value operand : op->getOperands())
236*755dc07dSRiver Riddle if (iterCarriedValSet.contains(operand))
237*755dc07dSRiver Riddle return true;
238*755dc07dSRiver Riddle
239*755dc07dSRiver Riddle return false;
240*755dc07dSRiver Riddle }
241*755dc07dSRiver Riddle
242*755dc07dSRiver Riddle /// Utility to match a generic reduction given a list of iteration-carried
243*755dc07dSRiver Riddle /// arguments, `iterCarriedArgs` and the position of the potential reduction
244*755dc07dSRiver Riddle /// argument within the list, `redPos`. If a reduction is matched, returns the
245*755dc07dSRiver Riddle /// reduced value and the topologically-sorted list of combiner operations
246*755dc07dSRiver Riddle /// involved in the reduction. Otherwise, returns a null value.
247*755dc07dSRiver Riddle ///
248*755dc07dSRiver Riddle /// The matching algorithm relies on the following invariants, which are subject
249*755dc07dSRiver Riddle /// to change:
250*755dc07dSRiver Riddle /// 1. The first combiner operation must be a binary operation with the
251*755dc07dSRiver Riddle /// iteration-carried value and the reduced value as operands.
252*755dc07dSRiver Riddle /// 2. The iteration-carried value and combiner operations must be side
253*755dc07dSRiver Riddle /// effect-free, have single result and a single use.
254*755dc07dSRiver Riddle /// 3. Combiner operations must be immediately nested in the region op
255*755dc07dSRiver Riddle /// performing the reduction.
256*755dc07dSRiver Riddle /// 4. Reduction def-use chain must end in a terminator op that yields the
257*755dc07dSRiver Riddle /// next iteration/output values in the same order as the iteration-carried
258*755dc07dSRiver Riddle /// values in `iterCarriedArgs`.
259*755dc07dSRiver Riddle /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
260*755dc07dSRiver Riddle /// of the region op performing the reduction.
261*755dc07dSRiver Riddle ///
262*755dc07dSRiver Riddle /// This utility is generic enough to detect reductions involving multiple
263*755dc07dSRiver Riddle /// combiner operations (disabled for now) across multiple dialects, including
264*755dc07dSRiver Riddle /// Linalg, Affine and SCF. For the sake of genericity, it does not return
265*755dc07dSRiver Riddle /// specific enum values for the combiner operations since its goal is also
266*755dc07dSRiver Riddle /// matching reductions without pre-defined semantics in core MLIR. It's up to
267*755dc07dSRiver Riddle /// each client to make sense out of the list of combiner operations. It's also
268*755dc07dSRiver Riddle /// up to each client to check for additional invariants on the expected
269*755dc07dSRiver Riddle /// reductions not covered by this generic matching.
matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,unsigned redPos,SmallVectorImpl<Operation * > & combinerOps)270*755dc07dSRiver Riddle Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
271*755dc07dSRiver Riddle unsigned redPos,
272*755dc07dSRiver Riddle SmallVectorImpl<Operation *> &combinerOps) {
273*755dc07dSRiver Riddle assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
274*755dc07dSRiver Riddle
275*755dc07dSRiver Riddle BlockArgument redCarriedVal = iterCarriedArgs[redPos];
276*755dc07dSRiver Riddle if (!redCarriedVal.hasOneUse())
277*755dc07dSRiver Riddle return nullptr;
278*755dc07dSRiver Riddle
279*755dc07dSRiver Riddle // For now, the first combiner op must be a binary op.
280*755dc07dSRiver Riddle Operation *combinerOp = *redCarriedVal.getUsers().begin();
281*755dc07dSRiver Riddle if (combinerOp->getNumOperands() != 2)
282*755dc07dSRiver Riddle return nullptr;
283*755dc07dSRiver Riddle Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
284*755dc07dSRiver Riddle ? combinerOp->getOperand(1)
285*755dc07dSRiver Riddle : combinerOp->getOperand(0);
286*755dc07dSRiver Riddle
287*755dc07dSRiver Riddle Operation *redRegionOp =
288*755dc07dSRiver Riddle iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
289*755dc07dSRiver Riddle if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
290*755dc07dSRiver Riddle return nullptr;
291*755dc07dSRiver Riddle
292*755dc07dSRiver Riddle // Traverse the def-use chain starting from the first combiner op until a
293*755dc07dSRiver Riddle // terminator is found. Gather all the combiner ops along the way in
294*755dc07dSRiver Riddle // topological order.
295*755dc07dSRiver Riddle while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
296*755dc07dSRiver Riddle if (!MemoryEffectOpInterface::hasNoEffect(combinerOp) ||
297*755dc07dSRiver Riddle combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() ||
298*755dc07dSRiver Riddle combinerOp->getParentOp() != redRegionOp)
299*755dc07dSRiver Riddle return nullptr;
300*755dc07dSRiver Riddle
301*755dc07dSRiver Riddle combinerOps.push_back(combinerOp);
302*755dc07dSRiver Riddle combinerOp = *combinerOp->getUsers().begin();
303*755dc07dSRiver Riddle }
304*755dc07dSRiver Riddle
305*755dc07dSRiver Riddle // Limit matching to single combiner op until we can properly test reductions
306*755dc07dSRiver Riddle // involving multiple combiners.
307*755dc07dSRiver Riddle if (combinerOps.size() != 1)
308*755dc07dSRiver Riddle return nullptr;
309*755dc07dSRiver Riddle
310*755dc07dSRiver Riddle // Check that the yielded value is in the same position as in
311*755dc07dSRiver Riddle // `iterCarriedArgs`.
312*755dc07dSRiver Riddle Operation *terminatorOp = combinerOp;
313*755dc07dSRiver Riddle if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
314*755dc07dSRiver Riddle return nullptr;
315*755dc07dSRiver Riddle
316*755dc07dSRiver Riddle return reducedVal;
317*755dc07dSRiver Riddle }
318