1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 
20 ///
21 /// Implements Analysis functions specific to slicing in Function.
22 ///
23 
24 using namespace mlir;
25 
getForwardSliceImpl(Operation * op,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)26 static void getForwardSliceImpl(Operation *op,
27                                 SetVector<Operation *> *forwardSlice,
28                                 TransitiveFilter filter) {
29   if (!op)
30     return;
31 
32   // Evaluate whether we should keep this use.
33   // This is useful in particular to implement scoping; i.e. return the
34   // transitive forwardSlice in the current scope.
35   if (filter && !filter(op))
36     return;
37 
38   for (Region &region : op->getRegions())
39     for (Block &block : region)
40       for (Operation &blockOp : block)
41         if (forwardSlice->count(&blockOp) == 0)
42           getForwardSliceImpl(&blockOp, forwardSlice, filter);
43   for (Value result : op->getResults()) {
44     for (Operation *userOp : result.getUsers())
45       if (forwardSlice->count(userOp) == 0)
46         getForwardSliceImpl(userOp, forwardSlice, filter);
47   }
48 
49   forwardSlice->insert(op);
50 }
51 
getForwardSlice(Operation * op,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)52 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
53                            TransitiveFilter filter) {
54   getForwardSliceImpl(op, forwardSlice, filter);
55   // Don't insert the top level operation, we just queried on it and don't
56   // want it in the results.
57   forwardSlice->remove(op);
58 
59   // Reverse to get back the actual topological order.
60   // std::reverse does not work out of the box on SetVector and I want an
61   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
62   std::vector<Operation *> v(forwardSlice->takeVector());
63   forwardSlice->insert(v.rbegin(), v.rend());
64 }
65 
getForwardSlice(Value root,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)66 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
67                            TransitiveFilter filter) {
68   for (Operation *user : root.getUsers())
69     getForwardSliceImpl(user, forwardSlice, filter);
70 
71   // Reverse to get back the actual topological order.
72   // std::reverse does not work out of the box on SetVector and I want an
73   // in-place swap based thing (the real std::reverse, not the LLVM adapter).
74   std::vector<Operation *> v(forwardSlice->takeVector());
75   forwardSlice->insert(v.rbegin(), v.rend());
76 }
77 
getBackwardSliceImpl(Operation * op,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)78 static void getBackwardSliceImpl(Operation *op,
79                                  SetVector<Operation *> *backwardSlice,
80                                  TransitiveFilter filter) {
81   if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
82     return;
83 
84   // Evaluate whether we should keep this def.
85   // This is useful in particular to implement scoping; i.e. return the
86   // transitive backwardSlice in the current scope.
87   if (filter && !filter(op))
88     return;
89 
90   for (const auto &en : llvm::enumerate(op->getOperands())) {
91     auto operand = en.value();
92     if (auto *definingOp = operand.getDefiningOp()) {
93       if (backwardSlice->count(definingOp) == 0)
94         getBackwardSliceImpl(definingOp, backwardSlice, filter);
95     } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
96       Block *block = blockArg.getOwner();
97       Operation *parentOp = block->getParentOp();
98       // TODO: determine whether we want to recurse backward into the other
99       // blocks of parentOp, which are not technically backward unless they flow
100       // into us. For now, just bail.
101       assert(parentOp->getNumRegions() == 1 &&
102              parentOp->getRegion(0).getBlocks().size() == 1);
103       if (backwardSlice->count(parentOp) == 0)
104         getBackwardSliceImpl(parentOp, backwardSlice, filter);
105     } else {
106       llvm_unreachable("No definingOp and not a block argument.");
107     }
108   }
109 
110   backwardSlice->insert(op);
111 }
112 
getBackwardSlice(Operation * op,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)113 void mlir::getBackwardSlice(Operation *op,
114                             SetVector<Operation *> *backwardSlice,
115                             TransitiveFilter filter) {
116   getBackwardSliceImpl(op, backwardSlice, filter);
117 
118   // Don't insert the top level operation, we just queried on it and don't
119   // want it in the results.
120   backwardSlice->remove(op);
121 }
122 
getBackwardSlice(Value root,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)123 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
124                             TransitiveFilter filter) {
125   if (Operation *definingOp = root.getDefiningOp()) {
126     getBackwardSlice(definingOp, backwardSlice, filter);
127     return;
128   }
129   Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
130   getBackwardSlice(bbAargOwner, backwardSlice, filter);
131 }
132 
getSlice(Operation * op,TransitiveFilter backwardFilter,TransitiveFilter forwardFilter)133 SetVector<Operation *> mlir::getSlice(Operation *op,
134                                       TransitiveFilter backwardFilter,
135                                       TransitiveFilter forwardFilter) {
136   SetVector<Operation *> slice;
137   slice.insert(op);
138 
139   unsigned currentIndex = 0;
140   SetVector<Operation *> backwardSlice;
141   SetVector<Operation *> forwardSlice;
142   while (currentIndex != slice.size()) {
143     auto *currentOp = (slice)[currentIndex];
144     // Compute and insert the backwardSlice starting from currentOp.
145     backwardSlice.clear();
146     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
147     slice.insert(backwardSlice.begin(), backwardSlice.end());
148 
149     // Compute and insert the forwardSlice starting from currentOp.
150     forwardSlice.clear();
151     getForwardSlice(currentOp, &forwardSlice, forwardFilter);
152     slice.insert(forwardSlice.begin(), forwardSlice.end());
153     ++currentIndex;
154   }
155   return topologicalSort(slice);
156 }
157 
158 namespace {
159 /// DFS post-order implementation that maintains a global count to work across
160 /// multiple invocations, to help implement topological sort on multi-root DAGs.
161 /// We traverse all operations but only record the ones that appear in
162 /// `toSort` for the final result.
163 struct DFSState {
DFSState__anonfb076aa00111::DFSState164   DFSState(const SetVector<Operation *> &set)
165       : toSort(set), topologicalCounts(), seen() {}
166   const SetVector<Operation *> &toSort;
167   SmallVector<Operation *, 16> topologicalCounts;
168   DenseSet<Operation *> seen;
169 };
170 } // namespace
171 
dfsPostorder(Operation * root,DFSState * state)172 static void dfsPostorder(Operation *root, DFSState *state) {
173   SmallVector<Operation *> queue(1, root);
174   std::vector<Operation *> ops;
175   while (!queue.empty()) {
176     Operation *current = queue.pop_back_val();
177     ops.push_back(current);
178     for (Value result : current->getResults()) {
179       for (Operation *op : result.getUsers())
180         queue.push_back(op);
181     }
182     for (Region &region : current->getRegions()) {
183       for (Operation &op : region.getOps())
184         queue.push_back(&op);
185     }
186   }
187 
188   for (Operation *op : llvm::reverse(ops)) {
189     if (state->seen.insert(op).second && state->toSort.count(op) > 0)
190       state->topologicalCounts.push_back(op);
191   }
192 }
193 
194 SetVector<Operation *>
topologicalSort(const SetVector<Operation * > & toSort)195 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
196   if (toSort.empty()) {
197     return toSort;
198   }
199 
200   // Run from each root with global count and `seen` set.
201   DFSState state(toSort);
202   for (auto *s : toSort) {
203     assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
204     dfsPostorder(s, &state);
205   }
206 
207   // Reorder and return.
208   SetVector<Operation *> res;
209   for (auto it = state.topologicalCounts.rbegin(),
210             eit = state.topologicalCounts.rend();
211        it != eit; ++it) {
212     res.insert(*it);
213   }
214   return res;
215 }
216 
217 /// Returns true if `value` (transitively) depends on iteration-carried values
218 /// of the given `ancestorOp`.
dependsOnCarriedVals(Value value,ArrayRef<BlockArgument> iterCarriedArgs,Operation * ancestorOp)219 static bool dependsOnCarriedVals(Value value,
220                                  ArrayRef<BlockArgument> iterCarriedArgs,
221                                  Operation *ancestorOp) {
222   // Compute the backward slice of the value.
223   SetVector<Operation *> slice;
224   getBackwardSlice(value, &slice,
225                    [&](Operation *op) { return !ancestorOp->isAncestor(op); });
226 
227   // Check that none of the operands of the operations in the backward slice are
228   // loop iteration arguments, and neither is the value itself.
229   SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
230                                           iterCarriedArgs.end());
231   if (iterCarriedValSet.contains(value))
232     return true;
233 
234   for (Operation *op : slice)
235     for (Value operand : op->getOperands())
236       if (iterCarriedValSet.contains(operand))
237         return true;
238 
239   return false;
240 }
241 
242 /// Utility to match a generic reduction given a list of iteration-carried
243 /// arguments, `iterCarriedArgs` and the position of the potential reduction
244 /// argument within the list, `redPos`. If a reduction is matched, returns the
245 /// reduced value and the topologically-sorted list of combiner operations
246 /// involved in the reduction. Otherwise, returns a null value.
247 ///
248 /// The matching algorithm relies on the following invariants, which are subject
249 /// to change:
250 ///  1. The first combiner operation must be a binary operation with the
251 ///     iteration-carried value and the reduced value as operands.
252 ///  2. The iteration-carried value and combiner operations must be side
253 ///     effect-free, have single result and a single use.
254 ///  3. Combiner operations must be immediately nested in the region op
255 ///     performing the reduction.
256 ///  4. Reduction def-use chain must end in a terminator op that yields the
257 ///     next iteration/output values in the same order as the iteration-carried
258 ///     values in `iterCarriedArgs`.
259 ///  5. `iterCarriedArgs` must contain all the iteration-carried/output values
260 ///     of the region op performing the reduction.
261 ///
262 /// This utility is generic enough to detect reductions involving multiple
263 /// combiner operations (disabled for now) across multiple dialects, including
264 /// Linalg, Affine and SCF. For the sake of genericity, it does not return
265 /// specific enum values for the combiner operations since its goal is also
266 /// matching reductions without pre-defined semantics in core MLIR. It's up to
267 /// each client to make sense out of the list of combiner operations. It's also
268 /// up to each client to check for additional invariants on the expected
269 /// reductions not covered by this generic matching.
matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,unsigned redPos,SmallVectorImpl<Operation * > & combinerOps)270 Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
271                            unsigned redPos,
272                            SmallVectorImpl<Operation *> &combinerOps) {
273   assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
274 
275   BlockArgument redCarriedVal = iterCarriedArgs[redPos];
276   if (!redCarriedVal.hasOneUse())
277     return nullptr;
278 
279   // For now, the first combiner op must be a binary op.
280   Operation *combinerOp = *redCarriedVal.getUsers().begin();
281   if (combinerOp->getNumOperands() != 2)
282     return nullptr;
283   Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
284                          ? combinerOp->getOperand(1)
285                          : combinerOp->getOperand(0);
286 
287   Operation *redRegionOp =
288       iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
289   if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
290     return nullptr;
291 
292   // Traverse the def-use chain starting from the first combiner op until a
293   // terminator is found. Gather all the combiner ops along the way in
294   // topological order.
295   while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
296     if (!MemoryEffectOpInterface::hasNoEffect(combinerOp) ||
297         combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() ||
298         combinerOp->getParentOp() != redRegionOp)
299       return nullptr;
300 
301     combinerOps.push_back(combinerOp);
302     combinerOp = *combinerOp->getUsers().begin();
303   }
304 
305   // Limit matching to single combiner op until we can properly test reductions
306   // involving multiple combiners.
307   if (combinerOps.size() != 1)
308     return nullptr;
309 
310   // Check that the yielded value is in the same position as in
311   // `iterCarriedArgs`.
312   Operation *terminatorOp = combinerOp;
313   if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
314     return nullptr;
315 
316   return reducedVal;
317 }
318