1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/SetVector.h"
21 
22 ///
23 /// Implements Analysis functions specific to slicing in Function.
24 ///
25 
26 using namespace mlir;
27 
28 static void getForwardSliceImpl(Operation *op,
29                                 SetVector<Operation *> *forwardSlice,
30                                 TransitiveFilter filter) {
31   if (!op)
32     return;
33 
34   // Evaluate whether we should keep this use.
35   // This is useful in particular to implement scoping; i.e. return the
36   // transitive forwardSlice in the current scope.
37   if (filter && !filter(op))
38     return;
39 
40   for (Region &region : op->getRegions())
41     for (Block &block : region)
42       for (Operation &blockOp : block)
43         if (forwardSlice->count(&blockOp) == 0)
44           getForwardSliceImpl(&blockOp, forwardSlice, filter);
45   for (Value result : op->getResults()) {
46     for (Operation *userOp : result.getUsers())
47       if (forwardSlice->count(userOp) == 0)
48         getForwardSliceImpl(userOp, forwardSlice, filter);
49   }
50 
51   forwardSlice->insert(op);
52 }
53 
54 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
55                            TransitiveFilter filter) {
56   getForwardSliceImpl(op, forwardSlice, filter);
57   // Don't insert the top level operation, we just queried on it and don't
58   // want it in the results.
59   forwardSlice->remove(op);
60 
61   // Reverse to get back the actual topological order.
62   // std::reverse does not work out of the box on SetVector and I want an
63   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
64   std::vector<Operation *> v(forwardSlice->takeVector());
65   forwardSlice->insert(v.rbegin(), v.rend());
66 }
67 
68 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
69                            TransitiveFilter filter) {
70   for (Operation *user : root.getUsers())
71     getForwardSliceImpl(user, forwardSlice, filter);
72 
73   // Reverse to get back the actual topological order.
74   // std::reverse does not work out of the box on SetVector and I want an
75   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
76   std::vector<Operation *> v(forwardSlice->takeVector());
77   forwardSlice->insert(v.rbegin(), v.rend());
78 }
79 
80 static void getBackwardSliceImpl(Operation *op,
81                                  SetVector<Operation *> *backwardSlice,
82                                  TransitiveFilter filter) {
83   if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
84     return;
85 
86   // Evaluate whether we should keep this def.
87   // This is useful in particular to implement scoping; i.e. return the
88   // transitive backwardSlice in the current scope.
89   if (filter && !filter(op))
90     return;
91 
92   for (auto en : llvm::enumerate(op->getOperands())) {
93     auto operand = en.value();
94     if (auto *definingOp = operand.getDefiningOp()) {
95       if (backwardSlice->count(definingOp) == 0)
96         getBackwardSliceImpl(definingOp, backwardSlice, filter);
97     } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
98       Block *block = blockArg.getOwner();
99       Operation *parentOp = block->getParentOp();
100       // TODO: determine whether we want to recurse backward into the other
101       // blocks of parentOp, which are not technically backward unless they flow
102       // into us. For now, just bail.
103       assert(parentOp->getNumRegions() == 1 &&
104              parentOp->getRegion(0).getBlocks().size() == 1);
105       if (backwardSlice->count(parentOp) == 0)
106         getBackwardSliceImpl(parentOp, backwardSlice, filter);
107     } else {
108       llvm_unreachable("No definingOp and not a block argument.");
109     }
110   }
111 
112   backwardSlice->insert(op);
113 }
114 
115 void mlir::getBackwardSlice(Operation *op,
116                             SetVector<Operation *> *backwardSlice,
117                             TransitiveFilter filter) {
118   getBackwardSliceImpl(op, backwardSlice, filter);
119 
120   // Don't insert the top level operation, we just queried on it and don't
121   // want it in the results.
122   backwardSlice->remove(op);
123 }
124 
125 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
126                             TransitiveFilter filter) {
127   if (Operation *definingOp = root.getDefiningOp()) {
128     getBackwardSlice(definingOp, backwardSlice, filter);
129     return;
130   }
131   Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
132   getBackwardSlice(bbAargOwner, backwardSlice, filter);
133 }
134 
135 SetVector<Operation *> mlir::getSlice(Operation *op,
136                                       TransitiveFilter backwardFilter,
137                                       TransitiveFilter forwardFilter) {
138   SetVector<Operation *> slice;
139   slice.insert(op);
140 
141   unsigned currentIndex = 0;
142   SetVector<Operation *> backwardSlice;
143   SetVector<Operation *> forwardSlice;
144   while (currentIndex != slice.size()) {
145     auto *currentOp = (slice)[currentIndex];
146     // Compute and insert the backwardSlice starting from currentOp.
147     backwardSlice.clear();
148     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
149     slice.insert(backwardSlice.begin(), backwardSlice.end());
150 
151     // Compute and insert the forwardSlice starting from currentOp.
152     forwardSlice.clear();
153     getForwardSlice(currentOp, &forwardSlice, forwardFilter);
154     slice.insert(forwardSlice.begin(), forwardSlice.end());
155     ++currentIndex;
156   }
157   return topologicalSort(slice);
158 }
159 
160 namespace {
161 /// DFS post-order implementation that maintains a global count to work across
162 /// multiple invocations, to help implement topological sort on multi-root DAGs.
163 /// We traverse all operations but only record the ones that appear in
164 /// `toSort` for the final result.
165 struct DFSState {
166   DFSState(const SetVector<Operation *> &set)
167       : toSort(set), topologicalCounts(), seen() {}
168   const SetVector<Operation *> &toSort;
169   SmallVector<Operation *, 16> topologicalCounts;
170   DenseSet<Operation *> seen;
171 };
172 } // namespace
173 
174 static void DFSPostorder(Operation *current, DFSState *state) {
175   for (Value result : current->getResults()) {
176     for (Operation *op : result.getUsers())
177       DFSPostorder(op, state);
178   }
179   bool inserted;
180   using IterTy = decltype(state->seen.begin());
181   IterTy iter;
182   std::tie(iter, inserted) = state->seen.insert(current);
183   if (inserted) {
184     if (state->toSort.count(current) > 0) {
185       state->topologicalCounts.push_back(current);
186     }
187   }
188 }
189 
190 SetVector<Operation *>
191 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
192   if (toSort.empty()) {
193     return toSort;
194   }
195 
196   // Run from each root with global count and `seen` set.
197   DFSState state(toSort);
198   for (auto *s : toSort) {
199     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
200     DFSPostorder(s, &state);
201   }
202 
203   // Reorder and return.
204   SetVector<Operation *> res;
205   for (auto it = state.topologicalCounts.rbegin(),
206             eit = state.topologicalCounts.rend();
207        it != eit; ++it) {
208     res.insert(*it);
209   }
210   return res;
211 }
212