1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements Analysis functions specific to slicing in Function.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Analysis/SliceAnalysis.h"
23 #include "mlir/Dialect/AffineOps/AffineOps.h"
24 #include "mlir/Dialect/LoopOps/LoopOps.h"
25 #include "mlir/IR/Function.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/Support/Functional.h"
28 #include "mlir/Support/LLVM.h"
29 #include "mlir/Support/STLExtras.h"
30 #include "llvm/ADT/SetVector.h"
31 
32 ///
33 /// Implements Analysis functions specific to slicing in Function.
34 ///
35 
36 using namespace mlir;
37 
38 using llvm::SetVector;
39 
40 static void getForwardSliceImpl(Operation *op,
41                                 SetVector<Operation *> *forwardSlice,
42                                 TransitiveFilter filter) {
43   if (!op) {
44     return;
45   }
46 
47   // Evaluate whether we should keep this use.
48   // This is useful in particular to implement scoping; i.e. return the
49   // transitive forwardSlice in the current scope.
50   if (!filter(op)) {
51     return;
52   }
53 
54   if (auto forOp = dyn_cast<AffineForOp>(op)) {
55     for (auto *ownerInst : forOp.getInductionVar()->getUsers())
56       if (forwardSlice->count(ownerInst) == 0)
57         getForwardSliceImpl(ownerInst, forwardSlice, filter);
58   } else if (auto forOp = dyn_cast<loop::ForOp>(op)) {
59     for (auto *ownerInst : forOp.getInductionVar()->getUsers())
60       if (forwardSlice->count(ownerInst) == 0)
61         getForwardSliceImpl(ownerInst, forwardSlice, filter);
62   } else {
63     assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
64     assert(op->getNumResults() <= 1 && "unexpected multiple results");
65     if (op->getNumResults() > 0) {
66       for (auto *ownerInst : op->getResult(0)->getUsers())
67         if (forwardSlice->count(ownerInst) == 0)
68           getForwardSliceImpl(ownerInst, forwardSlice, filter);
69     }
70   }
71 
72   forwardSlice->insert(op);
73 }
74 
75 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
76                            TransitiveFilter filter) {
77   getForwardSliceImpl(op, forwardSlice, filter);
78   // Don't insert the top level operation, we just queried on it and don't
79   // want it in the results.
80   forwardSlice->remove(op);
81 
82   // Reverse to get back the actual topological order.
83   // std::reverse does not work out of the box on SetVector and I want an
84   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
85   std::vector<Operation *> v(forwardSlice->takeVector());
86   forwardSlice->insert(v.rbegin(), v.rend());
87 }
88 
89 static void getBackwardSliceImpl(Operation *op,
90                                  SetVector<Operation *> *backwardSlice,
91                                  TransitiveFilter filter) {
92   if (!op)
93     return;
94 
95   assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
96           isa<loop::ForOp>(op)) &&
97          "unexpected generic op with regions");
98 
99   // Evaluate whether we should keep this def.
100   // This is useful in particular to implement scoping; i.e. return the
101   // transitive forwardSlice in the current scope.
102   if (!filter(op)) {
103     return;
104   }
105 
106   for (auto en : llvm::enumerate(op->getOperands())) {
107     auto *operand = en.value();
108     if (auto *blockArg = dyn_cast<BlockArgument>(operand)) {
109       if (auto affIv = getForInductionVarOwner(operand)) {
110         auto *affOp = affIv.getOperation();
111         if (backwardSlice->count(affOp) == 0)
112           getBackwardSliceImpl(affOp, backwardSlice, filter);
113       } else if (auto loopIv = loop::getForInductionVarOwner(operand)) {
114         auto *loopOp = loopIv.getOperation();
115         if (backwardSlice->count(loopOp) == 0)
116           getBackwardSliceImpl(loopOp, backwardSlice, filter);
117       } else if (blockArg->getOwner() !=
118                  &op->getParentOfType<FuncOp>().getBody().front()) {
119         op->emitError("unsupported CF for operand ") << en.index();
120         llvm_unreachable("Unsupported control flow");
121       }
122       continue;
123     }
124     auto *op = operand->getDefiningOp();
125     if (backwardSlice->count(op) == 0) {
126       getBackwardSliceImpl(op, backwardSlice, filter);
127     }
128   }
129 
130   backwardSlice->insert(op);
131 }
132 
133 void mlir::getBackwardSlice(Operation *op,
134                             SetVector<Operation *> *backwardSlice,
135                             TransitiveFilter filter) {
136   getBackwardSliceImpl(op, backwardSlice, filter);
137 
138   // Don't insert the top level operation, we just queried on it and don't
139   // want it in the results.
140   backwardSlice->remove(op);
141 }
142 
143 SetVector<Operation *> mlir::getSlice(Operation *op,
144                                       TransitiveFilter backwardFilter,
145                                       TransitiveFilter forwardFilter) {
146   SetVector<Operation *> slice;
147   slice.insert(op);
148 
149   unsigned currentIndex = 0;
150   SetVector<Operation *> backwardSlice;
151   SetVector<Operation *> forwardSlice;
152   while (currentIndex != slice.size()) {
153     auto *currentInst = (slice)[currentIndex];
154     // Compute and insert the backwardSlice starting from currentInst.
155     backwardSlice.clear();
156     getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
157     slice.insert(backwardSlice.begin(), backwardSlice.end());
158 
159     // Compute and insert the forwardSlice starting from currentInst.
160     forwardSlice.clear();
161     getForwardSlice(currentInst, &forwardSlice, forwardFilter);
162     slice.insert(forwardSlice.begin(), forwardSlice.end());
163     ++currentIndex;
164   }
165   return topologicalSort(slice);
166 }
167 
168 namespace {
169 /// DFS post-order implementation that maintains a global count to work across
170 /// multiple invocations, to help implement topological sort on multi-root DAGs.
171 /// We traverse all operations but only record the ones that appear in
172 /// `toSort` for the final result.
173 struct DFSState {
174   DFSState(const SetVector<Operation *> &set)
175       : toSort(set), topologicalCounts(), seen() {}
176   const SetVector<Operation *> &toSort;
177   SmallVector<Operation *, 16> topologicalCounts;
178   DenseSet<Operation *> seen;
179 };
180 } // namespace
181 
182 static void DFSPostorder(Operation *current, DFSState *state) {
183   assert(current->getNumResults() <= 1 && "NYI: multi-result");
184   if (current->getNumResults() > 0) {
185     for (auto &u : current->getResult(0)->getUses()) {
186       auto *op = u.getOwner();
187       DFSPostorder(op, state);
188     }
189   }
190   bool inserted;
191   using IterTy = decltype(state->seen.begin());
192   IterTy iter;
193   std::tie(iter, inserted) = state->seen.insert(current);
194   if (inserted) {
195     if (state->toSort.count(current) > 0) {
196       state->topologicalCounts.push_back(current);
197     }
198   }
199 }
200 
201 SetVector<Operation *>
202 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
203   if (toSort.empty()) {
204     return toSort;
205   }
206 
207   // Run from each root with global count and `seen` set.
208   DFSState state(toSort);
209   for (auto *s : toSort) {
210     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
211     DFSPostorder(s, &state);
212   }
213 
214   // Reorder and return.
215   SetVector<Operation *> res;
216   for (auto it = state.topologicalCounts.rbegin(),
217             eit = state.topologicalCounts.rend();
218        it != eit; ++it) {
219     res.insert(*it);
220   }
221   return res;
222 }
223