1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements Analysis functions specific to slicing in Function.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Analysis/SliceAnalysis.h"
23 #include "mlir/Analysis/VectorAnalysis.h"
24 #include "mlir/Dialect/AffineOps/AffineOps.h"
25 #include "mlir/Dialect/LoopOps/LoopOps.h"
26 #include "mlir/IR/Function.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/Support/Functional.h"
29 #include "mlir/Support/LLVM.h"
30 #include "mlir/Support/STLExtras.h"
31 #include "llvm/ADT/SetVector.h"
32 
33 ///
34 /// Implements Analysis functions specific to slicing in Function.
35 ///
36 
37 using namespace mlir;
38 
39 using llvm::SetVector;
40 
41 static void getForwardSliceImpl(Operation *op,
42                                 SetVector<Operation *> *forwardSlice,
43                                 TransitiveFilter filter) {
44   if (!op) {
45     return;
46   }
47 
48   // Evaluate whether we should keep this use.
49   // This is useful in particular to implement scoping; i.e. return the
50   // transitive forwardSlice in the current scope.
51   if (!filter(op)) {
52     return;
53   }
54 
55   if (auto forOp = dyn_cast<AffineForOp>(op)) {
56     for (auto *ownerInst : forOp.getInductionVar()->getUsers())
57       if (forwardSlice->count(ownerInst) == 0)
58         getForwardSliceImpl(ownerInst, forwardSlice, filter);
59   } else if (auto forOp = dyn_cast<loop::ForOp>(op)) {
60     for (auto *ownerInst : forOp.getInductionVar()->getUsers())
61       if (forwardSlice->count(ownerInst) == 0)
62         getForwardSliceImpl(ownerInst, forwardSlice, filter);
63   } else {
64     assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
65     assert(op->getNumResults() <= 1 && "unexpected multiple results");
66     if (op->getNumResults() > 0) {
67       for (auto *ownerInst : op->getResult(0)->getUsers())
68         if (forwardSlice->count(ownerInst) == 0)
69           getForwardSliceImpl(ownerInst, forwardSlice, filter);
70     }
71   }
72 
73   forwardSlice->insert(op);
74 }
75 
76 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
77                            TransitiveFilter filter) {
78   getForwardSliceImpl(op, forwardSlice, filter);
79   // Don't insert the top level operation, we just queried on it and don't
80   // want it in the results.
81   forwardSlice->remove(op);
82 
83   // Reverse to get back the actual topological order.
84   // std::reverse does not work out of the box on SetVector and I want an
85   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
86   std::vector<Operation *> v(forwardSlice->takeVector());
87   forwardSlice->insert(v.rbegin(), v.rend());
88 }
89 
90 static void getBackwardSliceImpl(Operation *op,
91                                  SetVector<Operation *> *backwardSlice,
92                                  TransitiveFilter filter) {
93   if (!op)
94     return;
95 
96   assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
97           isa<loop::ForOp>(op)) &&
98          "unexpected generic op with regions");
99 
100   // Evaluate whether we should keep this def.
101   // This is useful in particular to implement scoping; i.e. return the
102   // transitive forwardSlice in the current scope.
103   if (!filter(op)) {
104     return;
105   }
106 
107   for (auto en : llvm::enumerate(op->getOperands())) {
108     auto *operand = en.value();
109     if (auto *blockArg = dyn_cast<BlockArgument>(operand)) {
110       if (auto affIv = getForInductionVarOwner(operand)) {
111         auto *affOp = affIv.getOperation();
112         if (backwardSlice->count(affOp) == 0)
113           getBackwardSliceImpl(affOp, backwardSlice, filter);
114       } else if (auto loopIv = loop::getForInductionVarOwner(operand)) {
115         auto *loopOp = loopIv.getOperation();
116         if (backwardSlice->count(loopOp) == 0)
117           getBackwardSliceImpl(loopOp, backwardSlice, filter);
118       } else if (blockArg->getOwner() !=
119                  &op->getParentOfType<FuncOp>().getBody().front()) {
120         op->emitError("Unsupported CF for operand ") << en.index();
121         llvm_unreachable("Unsupported control flow");
122       }
123       continue;
124     }
125     auto *op = operand->getDefiningOp();
126     if (backwardSlice->count(op) == 0) {
127       getBackwardSliceImpl(op, backwardSlice, filter);
128     }
129   }
130 
131   backwardSlice->insert(op);
132 }
133 
134 void mlir::getBackwardSlice(Operation *op,
135                             SetVector<Operation *> *backwardSlice,
136                             TransitiveFilter filter) {
137   getBackwardSliceImpl(op, backwardSlice, filter);
138 
139   // Don't insert the top level operation, we just queried on it and don't
140   // want it in the results.
141   backwardSlice->remove(op);
142 }
143 
144 SetVector<Operation *> mlir::getSlice(Operation *op,
145                                       TransitiveFilter backwardFilter,
146                                       TransitiveFilter forwardFilter) {
147   SetVector<Operation *> slice;
148   slice.insert(op);
149 
150   unsigned currentIndex = 0;
151   SetVector<Operation *> backwardSlice;
152   SetVector<Operation *> forwardSlice;
153   while (currentIndex != slice.size()) {
154     auto *currentInst = (slice)[currentIndex];
155     // Compute and insert the backwardSlice starting from currentInst.
156     backwardSlice.clear();
157     getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
158     slice.insert(backwardSlice.begin(), backwardSlice.end());
159 
160     // Compute and insert the forwardSlice starting from currentInst.
161     forwardSlice.clear();
162     getForwardSlice(currentInst, &forwardSlice, forwardFilter);
163     slice.insert(forwardSlice.begin(), forwardSlice.end());
164     ++currentIndex;
165   }
166   return topologicalSort(slice);
167 }
168 
169 namespace {
170 /// DFS post-order implementation that maintains a global count to work across
171 /// multiple invocations, to help implement topological sort on multi-root DAGs.
172 /// We traverse all operations but only record the ones that appear in
173 /// `toSort` for the final result.
174 struct DFSState {
175   DFSState(const SetVector<Operation *> &set)
176       : toSort(set), topologicalCounts(), seen() {}
177   const SetVector<Operation *> &toSort;
178   SmallVector<Operation *, 16> topologicalCounts;
179   DenseSet<Operation *> seen;
180 };
181 } // namespace
182 
183 static void DFSPostorder(Operation *current, DFSState *state) {
184   assert(current->getNumResults() <= 1 && "NYI: multi-result");
185   if (current->getNumResults() > 0) {
186     for (auto &u : current->getResult(0)->getUses()) {
187       auto *op = u.getOwner();
188       DFSPostorder(op, state);
189     }
190   }
191   bool inserted;
192   using IterTy = decltype(state->seen.begin());
193   IterTy iter;
194   std::tie(iter, inserted) = state->seen.insert(current);
195   if (inserted) {
196     if (state->toSort.count(current) > 0) {
197       state->topologicalCounts.push_back(current);
198     }
199   }
200 }
201 
202 SetVector<Operation *>
203 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
204   if (toSort.empty()) {
205     return toSort;
206   }
207 
208   // Run from each root with global count and `seen` set.
209   DFSState state(toSort);
210   for (auto *s : toSort) {
211     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
212     DFSPostorder(s, &state);
213   }
214 
215   // Reorder and return.
216   SetVector<Operation *> res;
217   for (auto it = state.topologicalCounts.rbegin(),
218             eit = state.topologicalCounts.rend();
219        it != eit; ++it) {
220     res.insert(*it);
221   }
222   return res;
223 }
224