//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/IR/Function.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/SetVector.h"
20 
21 ///
22 /// Implements Analysis functions specific to slicing in Function.
23 ///
24 
25 using namespace mlir;
26 
27 using llvm::SetVector;
28 
29 static void getForwardSliceImpl(Operation *op,
30                                 SetVector<Operation *> *forwardSlice,
31                                 TransitiveFilter filter) {
32   if (!op) {
33     return;
34   }
35 
36   // Evaluate whether we should keep this use.
37   // This is useful in particular to implement scoping; i.e. return the
38   // transitive forwardSlice in the current scope.
39   if (!filter(op)) {
40     return;
41   }
42 
43   if (auto forOp = dyn_cast<AffineForOp>(op)) {
44     for (auto *ownerInst : forOp.getInductionVar().getUsers())
45       if (forwardSlice->count(ownerInst) == 0)
46         getForwardSliceImpl(ownerInst, forwardSlice, filter);
47   } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
48     for (auto *ownerInst : forOp.getInductionVar().getUsers())
49       if (forwardSlice->count(ownerInst) == 0)
50         getForwardSliceImpl(ownerInst, forwardSlice, filter);
51   } else {
52     assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
53     assert(op->getNumResults() <= 1 && "unexpected multiple results");
54     if (op->getNumResults() > 0) {
55       for (auto *ownerInst : op->getResult(0).getUsers())
56         if (forwardSlice->count(ownerInst) == 0)
57           getForwardSliceImpl(ownerInst, forwardSlice, filter);
58     }
59   }
60 
61   forwardSlice->insert(op);
62 }
63 
64 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
65                            TransitiveFilter filter) {
66   getForwardSliceImpl(op, forwardSlice, filter);
67   // Don't insert the top level operation, we just queried on it and don't
68   // want it in the results.
69   forwardSlice->remove(op);
70 
71   // Reverse to get back the actual topological order.
72   // std::reverse does not work out of the box on SetVector and I want an
73   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
74   std::vector<Operation *> v(forwardSlice->takeVector());
75   forwardSlice->insert(v.rbegin(), v.rend());
76 }
77 
78 static void getBackwardSliceImpl(Operation *op,
79                                  SetVector<Operation *> *backwardSlice,
80                                  TransitiveFilter filter) {
81   if (!op)
82     return;
83 
84   assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
85           isa<scf::ForOp>(op)) &&
86          "unexpected generic op with regions");
87 
88   // Evaluate whether we should keep this def.
89   // This is useful in particular to implement scoping; i.e. return the
90   // transitive forwardSlice in the current scope.
91   if (!filter(op)) {
92     return;
93   }
94 
95   for (auto en : llvm::enumerate(op->getOperands())) {
96     auto operand = en.value();
97     if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
98       if (auto affIv = getForInductionVarOwner(operand)) {
99         auto *affOp = affIv.getOperation();
100         if (backwardSlice->count(affOp) == 0)
101           getBackwardSliceImpl(affOp, backwardSlice, filter);
102       } else if (auto loopIv = scf::getForInductionVarOwner(operand)) {
103         auto *loopOp = loopIv.getOperation();
104         if (backwardSlice->count(loopOp) == 0)
105           getBackwardSliceImpl(loopOp, backwardSlice, filter);
106       } else if (blockArg.getOwner() !=
107                  &op->getParentOfType<FuncOp>().getBody().front()) {
108         op->emitError("unsupported CF for operand ") << en.index();
109         llvm_unreachable("Unsupported control flow");
110       }
111       continue;
112     }
113     auto *op = operand.getDefiningOp();
114     if (backwardSlice->count(op) == 0) {
115       getBackwardSliceImpl(op, backwardSlice, filter);
116     }
117   }
118 
119   backwardSlice->insert(op);
120 }
121 
122 void mlir::getBackwardSlice(Operation *op,
123                             SetVector<Operation *> *backwardSlice,
124                             TransitiveFilter filter) {
125   getBackwardSliceImpl(op, backwardSlice, filter);
126 
127   // Don't insert the top level operation, we just queried on it and don't
128   // want it in the results.
129   backwardSlice->remove(op);
130 }
131 
132 SetVector<Operation *> mlir::getSlice(Operation *op,
133                                       TransitiveFilter backwardFilter,
134                                       TransitiveFilter forwardFilter) {
135   SetVector<Operation *> slice;
136   slice.insert(op);
137 
138   unsigned currentIndex = 0;
139   SetVector<Operation *> backwardSlice;
140   SetVector<Operation *> forwardSlice;
141   while (currentIndex != slice.size()) {
142     auto *currentInst = (slice)[currentIndex];
143     // Compute and insert the backwardSlice starting from currentInst.
144     backwardSlice.clear();
145     getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
146     slice.insert(backwardSlice.begin(), backwardSlice.end());
147 
148     // Compute and insert the forwardSlice starting from currentInst.
149     forwardSlice.clear();
150     getForwardSlice(currentInst, &forwardSlice, forwardFilter);
151     slice.insert(forwardSlice.begin(), forwardSlice.end());
152     ++currentIndex;
153   }
154   return topologicalSort(slice);
155 }
156 
157 namespace {
158 /// DFS post-order implementation that maintains a global count to work across
159 /// multiple invocations, to help implement topological sort on multi-root DAGs.
160 /// We traverse all operations but only record the ones that appear in
161 /// `toSort` for the final result.
162 struct DFSState {
163   DFSState(const SetVector<Operation *> &set)
164       : toSort(set), topologicalCounts(), seen() {}
165   const SetVector<Operation *> &toSort;
166   SmallVector<Operation *, 16> topologicalCounts;
167   DenseSet<Operation *> seen;
168 };
169 } // namespace
170 
171 static void DFSPostorder(Operation *current, DFSState *state) {
172   assert(current->getNumResults() <= 1 && "NYI: multi-result");
173   if (current->getNumResults() > 0) {
174     for (auto &u : current->getResult(0).getUses()) {
175       auto *op = u.getOwner();
176       DFSPostorder(op, state);
177     }
178   }
179   bool inserted;
180   using IterTy = decltype(state->seen.begin());
181   IterTy iter;
182   std::tie(iter, inserted) = state->seen.insert(current);
183   if (inserted) {
184     if (state->toSort.count(current) > 0) {
185       state->topologicalCounts.push_back(current);
186     }
187   }
188 }
189 
190 SetVector<Operation *>
191 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
192   if (toSort.empty()) {
193     return toSort;
194   }
195 
196   // Run from each root with global count and `seen` set.
197   DFSState state(toSort);
198   for (auto *s : toSort) {
199     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
200     DFSPostorder(s, &state);
201   }
202 
203   // Reorder and return.
204   SetVector<Operation *> res;
205   for (auto it = state.topologicalCounts.rbegin(),
206             eit = state.topologicalCounts.rend();
207        it != eit; ++it) {
208     res.insert(*it);
209   }
210   return res;
211 }
212