1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/SetVector.h"
18 
///
/// Implements analysis functions for computing forward, backward and combined
/// slices of operations, and for topologically sorting sets of operations.
///
22 
23 using namespace mlir;
24 
25 static void getForwardSliceImpl(Operation *op,
26                                 SetVector<Operation *> *forwardSlice,
27                                 TransitiveFilter filter) {
28   if (!op)
29     return;
30 
31   // Evaluate whether we should keep this use.
32   // This is useful in particular to implement scoping; i.e. return the
33   // transitive forwardSlice in the current scope.
34   if (filter && !filter(op))
35     return;
36 
37   for (Region &region : op->getRegions())
38     for (Block &block : region)
39       for (Operation &blockOp : block)
40         if (forwardSlice->count(&blockOp) == 0)
41           getForwardSliceImpl(&blockOp, forwardSlice, filter);
42   for (Value result : op->getResults()) {
43     for (Operation *userOp : result.getUsers())
44       if (forwardSlice->count(userOp) == 0)
45         getForwardSliceImpl(userOp, forwardSlice, filter);
46   }
47 
48   forwardSlice->insert(op);
49 }
50 
51 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
52                            TransitiveFilter filter) {
53   getForwardSliceImpl(op, forwardSlice, filter);
54   // Don't insert the top level operation, we just queried on it and don't
55   // want it in the results.
56   forwardSlice->remove(op);
57 
58   // Reverse to get back the actual topological order.
59   // std::reverse does not work out of the box on SetVector and I want an
60   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
61   std::vector<Operation *> v(forwardSlice->takeVector());
62   forwardSlice->insert(v.rbegin(), v.rend());
63 }
64 
65 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
66                            TransitiveFilter filter) {
67   for (Operation *user : root.getUsers())
68     getForwardSliceImpl(user, forwardSlice, filter);
69 
70   // Reverse to get back the actual topological order.
71   // std::reverse does not work out of the box on SetVector and I want an
72   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
73   std::vector<Operation *> v(forwardSlice->takeVector());
74   forwardSlice->insert(v.rbegin(), v.rend());
75 }
76 
77 static void getBackwardSliceImpl(Operation *op,
78                                  SetVector<Operation *> *backwardSlice,
79                                  TransitiveFilter filter) {
80   if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
81     return;
82 
83   // Evaluate whether we should keep this def.
84   // This is useful in particular to implement scoping; i.e. return the
85   // transitive backwardSlice in the current scope.
86   if (filter && !filter(op))
87     return;
88 
89   for (auto en : llvm::enumerate(op->getOperands())) {
90     auto operand = en.value();
91     if (auto *definingOp = operand.getDefiningOp()) {
92       if (backwardSlice->count(definingOp) == 0)
93         getBackwardSliceImpl(definingOp, backwardSlice, filter);
94     } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
95       Block *block = blockArg.getOwner();
96       Operation *parentOp = block->getParentOp();
97       // TODO: determine whether we want to recurse backward into the other
98       // blocks of parentOp, which are not technically backward unless they flow
99       // into us. For now, just bail.
100       assert(parentOp->getNumRegions() == 1 &&
101              parentOp->getRegion(0).getBlocks().size() == 1);
102       if (backwardSlice->count(parentOp) == 0)
103         getBackwardSliceImpl(parentOp, backwardSlice, filter);
104     } else {
105       llvm_unreachable("No definingOp and not a block argument.");
106     }
107   }
108 
109   backwardSlice->insert(op);
110 }
111 
112 void mlir::getBackwardSlice(Operation *op,
113                             SetVector<Operation *> *backwardSlice,
114                             TransitiveFilter filter) {
115   getBackwardSliceImpl(op, backwardSlice, filter);
116 
117   // Don't insert the top level operation, we just queried on it and don't
118   // want it in the results.
119   backwardSlice->remove(op);
120 }
121 
122 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
123                             TransitiveFilter filter) {
124   if (Operation *definingOp = root.getDefiningOp()) {
125     getBackwardSlice(definingOp, backwardSlice, filter);
126     return;
127   }
128   Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
129   getBackwardSlice(bbAargOwner, backwardSlice, filter);
130 }
131 
132 SetVector<Operation *> mlir::getSlice(Operation *op,
133                                       TransitiveFilter backwardFilter,
134                                       TransitiveFilter forwardFilter) {
135   SetVector<Operation *> slice;
136   slice.insert(op);
137 
138   unsigned currentIndex = 0;
139   SetVector<Operation *> backwardSlice;
140   SetVector<Operation *> forwardSlice;
141   while (currentIndex != slice.size()) {
142     auto *currentOp = (slice)[currentIndex];
143     // Compute and insert the backwardSlice starting from currentOp.
144     backwardSlice.clear();
145     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
146     slice.insert(backwardSlice.begin(), backwardSlice.end());
147 
148     // Compute and insert the forwardSlice starting from currentOp.
149     forwardSlice.clear();
150     getForwardSlice(currentOp, &forwardSlice, forwardFilter);
151     slice.insert(forwardSlice.begin(), forwardSlice.end());
152     ++currentIndex;
153   }
154   return topologicalSort(slice);
155 }
156 
157 namespace {
158 /// DFS post-order implementation that maintains a global count to work across
159 /// multiple invocations, to help implement topological sort on multi-root DAGs.
160 /// We traverse all operations but only record the ones that appear in
161 /// `toSort` for the final result.
162 struct DFSState {
163   DFSState(const SetVector<Operation *> &set)
164       : toSort(set), topologicalCounts(), seen() {}
165   const SetVector<Operation *> &toSort;
166   SmallVector<Operation *, 16> topologicalCounts;
167   DenseSet<Operation *> seen;
168 };
169 } // namespace
170 
171 static void DFSPostorder(Operation *current, DFSState *state) {
172   for (Value result : current->getResults()) {
173     for (Operation *op : result.getUsers())
174       DFSPostorder(op, state);
175   }
176   bool inserted;
177   using IterTy = decltype(state->seen.begin());
178   IterTy iter;
179   std::tie(iter, inserted) = state->seen.insert(current);
180   if (inserted) {
181     if (state->toSort.count(current) > 0) {
182       state->topologicalCounts.push_back(current);
183     }
184   }
185 }
186 
187 SetVector<Operation *>
188 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
189   if (toSort.empty()) {
190     return toSort;
191   }
192 
193   // Run from each root with global count and `seen` set.
194   DFSState state(toSort);
195   for (auto *s : toSort) {
196     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
197     DFSPostorder(s, &state);
198   }
199 
200   // Reorder and return.
201   SetVector<Operation *> res;
202   for (auto it = state.topologicalCounts.rbegin(),
203             eit = state.topologicalCounts.rend();
204        it != eit; ++it) {
205     res.insert(*it);
206   }
207   return res;
208 }
209