1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/IR/Function.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/SetVector.h"
20 
21 ///
22 /// Implements Analysis functions specific to slicing in Function.
23 ///
24 
25 using namespace mlir;
26 
27 using llvm::SetVector;
28 
29 static void getForwardSliceImpl(Operation *op,
30                                 SetVector<Operation *> *forwardSlice,
31                                 TransitiveFilter filter) {
32   if (!op) {
33     return;
34   }
35 
36   // Evaluate whether we should keep this use.
37   // This is useful in particular to implement scoping; i.e. return the
38   // transitive forwardSlice in the current scope.
39   if (!filter(op)) {
40     return;
41   }
42 
43   if (auto forOp = dyn_cast<AffineForOp>(op)) {
44     for (Operation *userOp : forOp.getInductionVar().getUsers())
45       if (forwardSlice->count(userOp) == 0)
46         getForwardSliceImpl(userOp, forwardSlice, filter);
47   } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
48     for (Operation *userOp : forOp.getInductionVar().getUsers())
49       if (forwardSlice->count(userOp) == 0)
50         getForwardSliceImpl(userOp, forwardSlice, filter);
51     for (Value result : forOp.getResults())
52       for (Operation *userOp : result.getUsers())
53         if (forwardSlice->count(userOp) == 0)
54           getForwardSliceImpl(userOp, forwardSlice, filter);
55   } else {
56     assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
57     for (Value result : op->getResults()) {
58       for (Operation *userOp : result.getUsers())
59         if (forwardSlice->count(userOp) == 0)
60           getForwardSliceImpl(userOp, forwardSlice, filter);
61     }
62   }
63 
64   forwardSlice->insert(op);
65 }
66 
67 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
68                            TransitiveFilter filter) {
69   getForwardSliceImpl(op, forwardSlice, filter);
70   // Don't insert the top level operation, we just queried on it and don't
71   // want it in the results.
72   forwardSlice->remove(op);
73 
74   // Reverse to get back the actual topological order.
75   // std::reverse does not work out of the box on SetVector and I want an
76   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
77   std::vector<Operation *> v(forwardSlice->takeVector());
78   forwardSlice->insert(v.rbegin(), v.rend());
79 }
80 
81 static void getBackwardSliceImpl(Operation *op,
82                                  SetVector<Operation *> *backwardSlice,
83                                  TransitiveFilter filter) {
84   if (!op)
85     return;
86 
87   assert((op->getNumRegions() == 0 || isa<AffineForOp, scf::ForOp>(op)) &&
88          "unexpected generic op with regions");
89 
90   // Evaluate whether we should keep this def.
91   // This is useful in particular to implement scoping; i.e. return the
92   // transitive forwardSlice in the current scope.
93   if (!filter(op)) {
94     return;
95   }
96 
97   for (auto en : llvm::enumerate(op->getOperands())) {
98     auto operand = en.value();
99     if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
100       if (auto affIv = getForInductionVarOwner(operand)) {
101         auto *affOp = affIv.getOperation();
102         if (backwardSlice->count(affOp) == 0)
103           getBackwardSliceImpl(affOp, backwardSlice, filter);
104       } else if (auto loopIv = scf::getForInductionVarOwner(operand)) {
105         auto *loopOp = loopIv.getOperation();
106         if (backwardSlice->count(loopOp) == 0)
107           getBackwardSliceImpl(loopOp, backwardSlice, filter);
108       } else if (blockArg.getOwner() !=
109                  &op->getParentOfType<FuncOp>().getBody().front()) {
110         op->emitError("unsupported CF for operand ") << en.index();
111         llvm_unreachable("Unsupported control flow");
112       }
113       continue;
114     }
115     auto *op = operand.getDefiningOp();
116     if (backwardSlice->count(op) == 0) {
117       getBackwardSliceImpl(op, backwardSlice, filter);
118     }
119   }
120 
121   backwardSlice->insert(op);
122 }
123 
124 void mlir::getBackwardSlice(Operation *op,
125                             SetVector<Operation *> *backwardSlice,
126                             TransitiveFilter filter) {
127   getBackwardSliceImpl(op, backwardSlice, filter);
128 
129   // Don't insert the top level operation, we just queried on it and don't
130   // want it in the results.
131   backwardSlice->remove(op);
132 }
133 
134 SetVector<Operation *> mlir::getSlice(Operation *op,
135                                       TransitiveFilter backwardFilter,
136                                       TransitiveFilter forwardFilter) {
137   SetVector<Operation *> slice;
138   slice.insert(op);
139 
140   unsigned currentIndex = 0;
141   SetVector<Operation *> backwardSlice;
142   SetVector<Operation *> forwardSlice;
143   while (currentIndex != slice.size()) {
144     auto *currentOp = (slice)[currentIndex];
145     // Compute and insert the backwardSlice starting from currentOp.
146     backwardSlice.clear();
147     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
148     slice.insert(backwardSlice.begin(), backwardSlice.end());
149 
150     // Compute and insert the forwardSlice starting from currentOp.
151     forwardSlice.clear();
152     getForwardSlice(currentOp, &forwardSlice, forwardFilter);
153     slice.insert(forwardSlice.begin(), forwardSlice.end());
154     ++currentIndex;
155   }
156   return topologicalSort(slice);
157 }
158 
159 namespace {
160 /// DFS post-order implementation that maintains a global count to work across
161 /// multiple invocations, to help implement topological sort on multi-root DAGs.
162 /// We traverse all operations but only record the ones that appear in
163 /// `toSort` for the final result.
164 struct DFSState {
165   DFSState(const SetVector<Operation *> &set)
166       : toSort(set), topologicalCounts(), seen() {}
167   const SetVector<Operation *> &toSort;
168   SmallVector<Operation *, 16> topologicalCounts;
169   DenseSet<Operation *> seen;
170 };
171 } // namespace
172 
173 static void DFSPostorder(Operation *current, DFSState *state) {
174   for (Value result : current->getResults()) {
175     for (Operation *op : result.getUsers())
176       DFSPostorder(op, state);
177   }
178   bool inserted;
179   using IterTy = decltype(state->seen.begin());
180   IterTy iter;
181   std::tie(iter, inserted) = state->seen.insert(current);
182   if (inserted) {
183     if (state->toSort.count(current) > 0) {
184       state->topologicalCounts.push_back(current);
185     }
186   }
187 }
188 
189 SetVector<Operation *>
190 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
191   if (toSort.empty()) {
192     return toSort;
193   }
194 
195   // Run from each root with global count and `seen` set.
196   DFSState state(toSort);
197   for (auto *s : toSort) {
198     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
199     DFSPostorder(s, &state);
200   }
201 
202   // Reorder and return.
203   SetVector<Operation *> res;
204   for (auto it = state.topologicalCounts.rbegin(),
205             eit = state.topologicalCounts.rend();
206        it != eit; ++it) {
207     res.insert(*it);
208   }
209   return res;
210 }
211