//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements analysis functions for computing slices of operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/SetVector.h"
18 
19 ///
20 /// Implements Analysis functions specific to slicing in Function.
21 ///
22 
23 using namespace mlir;
24 
25 static void getForwardSliceImpl(Operation *op,
26                                 SetVector<Operation *> *forwardSlice,
27                                 TransitiveFilter filter) {
28   if (!op)
29     return;
30 
31   // Evaluate whether we should keep this use.
32   // This is useful in particular to implement scoping; i.e. return the
33   // transitive forwardSlice in the current scope.
34   if (filter && !filter(op))
35     return;
36 
37   for (Region &region : op->getRegions())
38     for (Block &block : region)
39       for (Operation &blockOp : block)
40         if (forwardSlice->count(&blockOp) == 0)
41           getForwardSliceImpl(&blockOp, forwardSlice, filter);
42   for (Value result : op->getResults()) {
43     for (Operation *userOp : result.getUsers())
44       if (forwardSlice->count(userOp) == 0)
45         getForwardSliceImpl(userOp, forwardSlice, filter);
46   }
47 
48   forwardSlice->insert(op);
49 }
50 
51 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
52                            TransitiveFilter filter) {
53   getForwardSliceImpl(op, forwardSlice, filter);
54   // Don't insert the top level operation, we just queried on it and don't
55   // want it in the results.
56   forwardSlice->remove(op);
57 
58   // Reverse to get back the actual topological order.
59   // std::reverse does not work out of the box on SetVector and I want an
60   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
61   std::vector<Operation *> v(forwardSlice->takeVector());
62   forwardSlice->insert(v.rbegin(), v.rend());
63 }
64 
65 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
66                            TransitiveFilter filter) {
67   for (Operation *user : root.getUsers())
68     getForwardSliceImpl(user, forwardSlice, filter);
69 
70   // Reverse to get back the actual topological order.
71   // std::reverse does not work out of the box on SetVector and I want an
72   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
73   std::vector<Operation *> v(forwardSlice->takeVector());
74   forwardSlice->insert(v.rbegin(), v.rend());
75 }
76 
77 static void getBackwardSliceImpl(Operation *op,
78                                  SetVector<Operation *> *backwardSlice,
79                                  TransitiveFilter filter) {
80   if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
81     return;
82 
83   // Evaluate whether we should keep this def.
84   // This is useful in particular to implement scoping; i.e. return the
85   // transitive backwardSlice in the current scope.
86   if (filter && !filter(op))
87     return;
88 
89   for (auto en : llvm::enumerate(op->getOperands())) {
90     auto operand = en.value();
91     if (auto *definingOp = operand.getDefiningOp()) {
92       if (backwardSlice->count(definingOp) == 0)
93         getBackwardSliceImpl(definingOp, backwardSlice, filter);
94     } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
95       Block *block = blockArg.getOwner();
96       Operation *parentOp = block->getParentOp();
97       // TODO: determine whether we want to recurse backward into the other
98       // blocks of parentOp, which are not technically backward unless they flow
99       // into us. For now, just bail.
100       assert(parentOp->getNumRegions() == 1 &&
101              parentOp->getRegion(0).getBlocks().size() == 1);
102       if (backwardSlice->count(parentOp) == 0)
103         getBackwardSliceImpl(parentOp, backwardSlice, filter);
104     } else {
105       llvm_unreachable("No definingOp and not a block argument.");
106     }
107   }
108 
109   backwardSlice->insert(op);
110 }
111 
112 void mlir::getBackwardSlice(Operation *op,
113                             SetVector<Operation *> *backwardSlice,
114                             TransitiveFilter filter) {
115   getBackwardSliceImpl(op, backwardSlice, filter);
116 
117   // Don't insert the top level operation, we just queried on it and don't
118   // want it in the results.
119   backwardSlice->remove(op);
120 }
121 
122 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
123                             TransitiveFilter filter) {
124   if (Operation *definingOp = root.getDefiningOp()) {
125     getBackwardSlice(definingOp, backwardSlice, filter);
126     return;
127   }
128   Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
129   getBackwardSlice(bbAargOwner, backwardSlice, filter);
130 }
131 
132 SetVector<Operation *> mlir::getSlice(Operation *op,
133                                       TransitiveFilter backwardFilter,
134                                       TransitiveFilter forwardFilter) {
135   SetVector<Operation *> slice;
136   slice.insert(op);
137 
138   unsigned currentIndex = 0;
139   SetVector<Operation *> backwardSlice;
140   SetVector<Operation *> forwardSlice;
141   while (currentIndex != slice.size()) {
142     auto *currentOp = (slice)[currentIndex];
143     // Compute and insert the backwardSlice starting from currentOp.
144     backwardSlice.clear();
145     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
146     slice.insert(backwardSlice.begin(), backwardSlice.end());
147 
148     // Compute and insert the forwardSlice starting from currentOp.
149     forwardSlice.clear();
150     getForwardSlice(currentOp, &forwardSlice, forwardFilter);
151     slice.insert(forwardSlice.begin(), forwardSlice.end());
152     ++currentIndex;
153   }
154   return topologicalSort(slice);
155 }
156 
157 namespace {
158 /// DFS post-order implementation that maintains a global count to work across
159 /// multiple invocations, to help implement topological sort on multi-root DAGs.
160 /// We traverse all operations but only record the ones that appear in
161 /// `toSort` for the final result.
162 struct DFSState {
163   DFSState(const SetVector<Operation *> &set)
164       : toSort(set), topologicalCounts(), seen() {}
165   const SetVector<Operation *> &toSort;
166   SmallVector<Operation *, 16> topologicalCounts;
167   DenseSet<Operation *> seen;
168 };
169 } // namespace
170 
171 static void DFSPostorder(Operation *root, DFSState *state) {
172   SmallVector<Operation *> queue(1, root);
173   std::vector<Operation *> ops;
174   while (!queue.empty()) {
175     Operation *current = queue.pop_back_val();
176     ops.push_back(current);
177     for (Value result : current->getResults()) {
178       for (Operation *op : result.getUsers())
179         queue.push_back(op);
180     }
181     for (Region &region : current->getRegions()) {
182       for (Operation &op : region.getOps())
183         queue.push_back(&op);
184     }
185   }
186 
187   for (Operation *op : llvm::reverse(ops)) {
188     if (state->seen.insert(op).second && state->toSort.count(op) > 0)
189       state->topologicalCounts.push_back(op);
190   }
191 }
192 
193 SetVector<Operation *>
194 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
195   if (toSort.empty()) {
196     return toSort;
197   }
198 
199   // Run from each root with global count and `seen` set.
200   DFSState state(toSort);
201   for (auto *s : toSort) {
202     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
203     DFSPostorder(s, &state);
204   }
205 
206   // Reorder and return.
207   SetVector<Operation *> res;
208   for (auto it = state.topologicalCounts.rbegin(),
209             eit = state.topologicalCounts.rend();
210        it != eit; ++it) {
211     res.insert(*it);
212   }
213   return res;
214 }
215