//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
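// This file implements analysis functions that slice the IR along transitive
// use-def chains: forward, backward and combined slices of operations, plus a
// topological sort of the resulting operation sets.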
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/SetVector.h"
21 
22 ///
/// Implements forward/backward/combined slicing of operations along use-def
/// chains, and topological sorting of operation sets.
24 ///
25 
26 using namespace mlir;
27 
28 using llvm::SetVector;
29 
30 static void getForwardSliceImpl(Operation *op,
31                                 SetVector<Operation *> *forwardSlice,
32                                 TransitiveFilter filter) {
33   if (!op)
34     return;
35 
36   // Evaluate whether we should keep this use.
37   // This is useful in particular to implement scoping; i.e. return the
38   // transitive forwardSlice in the current scope.
39   if (filter && !filter(op))
40     return;
41 
42   for (Region &region : op->getRegions())
43     for (Block &block : region)
44       for (Operation &blockOp : block)
45         if (forwardSlice->count(&blockOp) == 0)
46           getForwardSliceImpl(&blockOp, forwardSlice, filter);
47   for (Value result : op->getResults()) {
48     for (Operation *userOp : result.getUsers())
49       if (forwardSlice->count(userOp) == 0)
50         getForwardSliceImpl(userOp, forwardSlice, filter);
51   }
52 
53   forwardSlice->insert(op);
54 }
55 
56 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
57                            TransitiveFilter filter) {
58   getForwardSliceImpl(op, forwardSlice, filter);
59   // Don't insert the top level operation, we just queried on it and don't
60   // want it in the results.
61   forwardSlice->remove(op);
62 
63   // Reverse to get back the actual topological order.
64   // std::reverse does not work out of the box on SetVector and I want an
65   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
66   std::vector<Operation *> v(forwardSlice->takeVector());
67   forwardSlice->insert(v.rbegin(), v.rend());
68 }
69 
70 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
71                            TransitiveFilter filter) {
72   for (Operation *user : root.getUsers())
73     getForwardSliceImpl(user, forwardSlice, filter);
74 
75   // Reverse to get back the actual topological order.
76   // std::reverse does not work out of the box on SetVector and I want an
77   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
78   std::vector<Operation *> v(forwardSlice->takeVector());
79   forwardSlice->insert(v.rbegin(), v.rend());
80 }
81 
82 static void getBackwardSliceImpl(Operation *op,
83                                  SetVector<Operation *> *backwardSlice,
84                                  TransitiveFilter filter) {
85   if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
86     return;
87 
88   // Evaluate whether we should keep this def.
89   // This is useful in particular to implement scoping; i.e. return the
90   // transitive backwardSlice in the current scope.
91   if (filter && !filter(op))
92     return;
93 
94   for (auto en : llvm::enumerate(op->getOperands())) {
95     auto operand = en.value();
96     if (auto *definingOp = operand.getDefiningOp()) {
97       if (backwardSlice->count(definingOp) == 0)
98         getBackwardSliceImpl(definingOp, backwardSlice, filter);
99     } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
100       Block *block = blockArg.getOwner();
101       Operation *parentOp = block->getParentOp();
102       // TODO: determine whether we want to recurse backward into the other
103       // blocks of parentOp, which are not technically backward unless they flow
104       // into us. For now, just bail.
105       assert(parentOp->getNumRegions() == 1 &&
106              parentOp->getRegion(0).getBlocks().size() == 1);
107       if (backwardSlice->count(parentOp) == 0)
108         getBackwardSliceImpl(parentOp, backwardSlice, filter);
109     } else {
110       llvm_unreachable("No definingOp and not a block argument.");
111     }
112   }
113 
114   backwardSlice->insert(op);
115 }
116 
117 void mlir::getBackwardSlice(Operation *op,
118                             SetVector<Operation *> *backwardSlice,
119                             TransitiveFilter filter) {
120   getBackwardSliceImpl(op, backwardSlice, filter);
121 
122   // Don't insert the top level operation, we just queried on it and don't
123   // want it in the results.
124   backwardSlice->remove(op);
125 }
126 
127 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
128                             TransitiveFilter filter) {
129   if (Operation *definingOp = root.getDefiningOp()) {
130     getBackwardSlice(definingOp, backwardSlice, filter);
131     return;
132   }
133   Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
134   getBackwardSlice(bbAargOwner, backwardSlice, filter);
135 }
136 
137 SetVector<Operation *> mlir::getSlice(Operation *op,
138                                       TransitiveFilter backwardFilter,
139                                       TransitiveFilter forwardFilter) {
140   SetVector<Operation *> slice;
141   slice.insert(op);
142 
143   unsigned currentIndex = 0;
144   SetVector<Operation *> backwardSlice;
145   SetVector<Operation *> forwardSlice;
146   while (currentIndex != slice.size()) {
147     auto *currentOp = (slice)[currentIndex];
148     // Compute and insert the backwardSlice starting from currentOp.
149     backwardSlice.clear();
150     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
151     slice.insert(backwardSlice.begin(), backwardSlice.end());
152 
153     // Compute and insert the forwardSlice starting from currentOp.
154     forwardSlice.clear();
155     getForwardSlice(currentOp, &forwardSlice, forwardFilter);
156     slice.insert(forwardSlice.begin(), forwardSlice.end());
157     ++currentIndex;
158   }
159   return topologicalSort(slice);
160 }
161 
162 namespace {
163 /// DFS post-order implementation that maintains a global count to work across
164 /// multiple invocations, to help implement topological sort on multi-root DAGs.
165 /// We traverse all operations but only record the ones that appear in
166 /// `toSort` for the final result.
167 struct DFSState {
168   DFSState(const SetVector<Operation *> &set)
169       : toSort(set), topologicalCounts(), seen() {}
170   const SetVector<Operation *> &toSort;
171   SmallVector<Operation *, 16> topologicalCounts;
172   DenseSet<Operation *> seen;
173 };
174 } // namespace
175 
176 static void DFSPostorder(Operation *current, DFSState *state) {
177   for (Value result : current->getResults()) {
178     for (Operation *op : result.getUsers())
179       DFSPostorder(op, state);
180   }
181   bool inserted;
182   using IterTy = decltype(state->seen.begin());
183   IterTy iter;
184   std::tie(iter, inserted) = state->seen.insert(current);
185   if (inserted) {
186     if (state->toSort.count(current) > 0) {
187       state->topologicalCounts.push_back(current);
188     }
189   }
190 }
191 
192 SetVector<Operation *>
193 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
194   if (toSort.empty()) {
195     return toSort;
196   }
197 
198   // Run from each root with global count and `seen` set.
199   DFSState state(toSort);
200   for (auto *s : toSort) {
201     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
202     DFSPostorder(s, &state);
203   }
204 
205   // Reorder and return.
206   SetVector<Operation *> res;
207   for (auto it = state.topologicalCounts.rbegin(),
208             eit = state.topologicalCounts.rend();
209        it != eit; ++it) {
210     res.insert(*it);
211   }
212   return res;
213 }
214