1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/LoopOps/LoopOps.h"
16 #include "mlir/IR/Function.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/LLVM.h"
19 #include "mlir/Support/STLExtras.h"
20 #include "llvm/ADT/SetVector.h"
21 
22 ///
23 /// Implements Analysis functions specific to slicing in Function.
24 ///
25 
26 using namespace mlir;
27 
28 using llvm::SetVector;
29 
30 static void getForwardSliceImpl(Operation *op,
31                                 SetVector<Operation *> *forwardSlice,
32                                 TransitiveFilter filter) {
33   if (!op) {
34     return;
35   }
36 
37   // Evaluate whether we should keep this use.
38   // This is useful in particular to implement scoping; i.e. return the
39   // transitive forwardSlice in the current scope.
40   if (!filter(op)) {
41     return;
42   }
43 
44   if (auto forOp = dyn_cast<AffineForOp>(op)) {
45     for (auto *ownerInst : forOp.getInductionVar().getUsers())
46       if (forwardSlice->count(ownerInst) == 0)
47         getForwardSliceImpl(ownerInst, forwardSlice, filter);
48   } else if (auto forOp = dyn_cast<loop::ForOp>(op)) {
49     for (auto *ownerInst : forOp.getInductionVar().getUsers())
50       if (forwardSlice->count(ownerInst) == 0)
51         getForwardSliceImpl(ownerInst, forwardSlice, filter);
52   } else {
53     assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
54     assert(op->getNumResults() <= 1 && "unexpected multiple results");
55     if (op->getNumResults() > 0) {
56       for (auto *ownerInst : op->getResult(0).getUsers())
57         if (forwardSlice->count(ownerInst) == 0)
58           getForwardSliceImpl(ownerInst, forwardSlice, filter);
59     }
60   }
61 
62   forwardSlice->insert(op);
63 }
64 
65 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
66                            TransitiveFilter filter) {
67   getForwardSliceImpl(op, forwardSlice, filter);
68   // Don't insert the top level operation, we just queried on it and don't
69   // want it in the results.
70   forwardSlice->remove(op);
71 
72   // Reverse to get back the actual topological order.
73   // std::reverse does not work out of the box on SetVector and I want an
74   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
75   std::vector<Operation *> v(forwardSlice->takeVector());
76   forwardSlice->insert(v.rbegin(), v.rend());
77 }
78 
79 static void getBackwardSliceImpl(Operation *op,
80                                  SetVector<Operation *> *backwardSlice,
81                                  TransitiveFilter filter) {
82   if (!op)
83     return;
84 
85   assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
86           isa<loop::ForOp>(op)) &&
87          "unexpected generic op with regions");
88 
89   // Evaluate whether we should keep this def.
90   // This is useful in particular to implement scoping; i.e. return the
91   // transitive forwardSlice in the current scope.
92   if (!filter(op)) {
93     return;
94   }
95 
96   for (auto en : llvm::enumerate(op->getOperands())) {
97     auto operand = en.value();
98     if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
99       if (auto affIv = getForInductionVarOwner(operand)) {
100         auto *affOp = affIv.getOperation();
101         if (backwardSlice->count(affOp) == 0)
102           getBackwardSliceImpl(affOp, backwardSlice, filter);
103       } else if (auto loopIv = loop::getForInductionVarOwner(operand)) {
104         auto *loopOp = loopIv.getOperation();
105         if (backwardSlice->count(loopOp) == 0)
106           getBackwardSliceImpl(loopOp, backwardSlice, filter);
107       } else if (blockArg.getOwner() !=
108                  &op->getParentOfType<FuncOp>().getBody().front()) {
109         op->emitError("unsupported CF for operand ") << en.index();
110         llvm_unreachable("Unsupported control flow");
111       }
112       continue;
113     }
114     auto *op = operand.getDefiningOp();
115     if (backwardSlice->count(op) == 0) {
116       getBackwardSliceImpl(op, backwardSlice, filter);
117     }
118   }
119 
120   backwardSlice->insert(op);
121 }
122 
123 void mlir::getBackwardSlice(Operation *op,
124                             SetVector<Operation *> *backwardSlice,
125                             TransitiveFilter filter) {
126   getBackwardSliceImpl(op, backwardSlice, filter);
127 
128   // Don't insert the top level operation, we just queried on it and don't
129   // want it in the results.
130   backwardSlice->remove(op);
131 }
132 
133 SetVector<Operation *> mlir::getSlice(Operation *op,
134                                       TransitiveFilter backwardFilter,
135                                       TransitiveFilter forwardFilter) {
136   SetVector<Operation *> slice;
137   slice.insert(op);
138 
139   unsigned currentIndex = 0;
140   SetVector<Operation *> backwardSlice;
141   SetVector<Operation *> forwardSlice;
142   while (currentIndex != slice.size()) {
143     auto *currentInst = (slice)[currentIndex];
144     // Compute and insert the backwardSlice starting from currentInst.
145     backwardSlice.clear();
146     getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
147     slice.insert(backwardSlice.begin(), backwardSlice.end());
148 
149     // Compute and insert the forwardSlice starting from currentInst.
150     forwardSlice.clear();
151     getForwardSlice(currentInst, &forwardSlice, forwardFilter);
152     slice.insert(forwardSlice.begin(), forwardSlice.end());
153     ++currentIndex;
154   }
155   return topologicalSort(slice);
156 }
157 
158 namespace {
159 /// DFS post-order implementation that maintains a global count to work across
160 /// multiple invocations, to help implement topological sort on multi-root DAGs.
161 /// We traverse all operations but only record the ones that appear in
162 /// `toSort` for the final result.
163 struct DFSState {
164   DFSState(const SetVector<Operation *> &set)
165       : toSort(set), topologicalCounts(), seen() {}
166   const SetVector<Operation *> &toSort;
167   SmallVector<Operation *, 16> topologicalCounts;
168   DenseSet<Operation *> seen;
169 };
170 } // namespace
171 
172 static void DFSPostorder(Operation *current, DFSState *state) {
173   assert(current->getNumResults() <= 1 && "NYI: multi-result");
174   if (current->getNumResults() > 0) {
175     for (auto &u : current->getResult(0).getUses()) {
176       auto *op = u.getOwner();
177       DFSPostorder(op, state);
178     }
179   }
180   bool inserted;
181   using IterTy = decltype(state->seen.begin());
182   IterTy iter;
183   std::tie(iter, inserted) = state->seen.insert(current);
184   if (inserted) {
185     if (state->toSort.count(current) > 0) {
186       state->topologicalCounts.push_back(current);
187     }
188   }
189 }
190 
191 SetVector<Operation *>
192 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
193   if (toSort.empty()) {
194     return toSort;
195   }
196 
197   // Run from each root with global count and `seen` set.
198   DFSState state(toSort);
199   for (auto *s : toSort) {
200     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
201     DFSPostorder(s, &state);
202   }
203 
204   // Reorder and return.
205   SetVector<Operation *> res;
206   for (auto it = state.topologicalCounts.rbegin(),
207             eit = state.topologicalCounts.rend();
208        it != eit; ++it) {
209     res.insert(*it);
210   }
211   return res;
212 }
213