1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/Support/LLVM.h" 17 #include "llvm/ADT/SetVector.h" 18 19 /// 20 /// Implements Analysis functions specific to slicing in Function. 21 /// 22 23 using namespace mlir; 24 25 static void getForwardSliceImpl(Operation *op, 26 SetVector<Operation *> *forwardSlice, 27 TransitiveFilter filter) { 28 if (!op) 29 return; 30 31 // Evaluate whether we should keep this use. 32 // This is useful in particular to implement scoping; i.e. return the 33 // transitive forwardSlice in the current scope. 34 if (filter && !filter(op)) 35 return; 36 37 for (Region ®ion : op->getRegions()) 38 for (Block &block : region) 39 for (Operation &blockOp : block) 40 if (forwardSlice->count(&blockOp) == 0) 41 getForwardSliceImpl(&blockOp, forwardSlice, filter); 42 for (Value result : op->getResults()) { 43 for (Operation *userOp : result.getUsers()) 44 if (forwardSlice->count(userOp) == 0) 45 getForwardSliceImpl(userOp, forwardSlice, filter); 46 } 47 48 forwardSlice->insert(op); 49 } 50 51 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 52 TransitiveFilter filter) { 53 getForwardSliceImpl(op, forwardSlice, filter); 54 // Don't insert the top level operation, we just queried on it and don't 55 // want it in the results. 56 forwardSlice->remove(op); 57 58 // Reverse to get back the actual topological order. 59 // std::reverse does not work out of the box on SetVector and I want an 60 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 61 std::vector<Operation *> v(forwardSlice->takeVector()); 62 forwardSlice->insert(v.rbegin(), v.rend()); 63 } 64 65 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, 66 TransitiveFilter filter) { 67 for (Operation *user : root.getUsers()) 68 getForwardSliceImpl(user, forwardSlice, filter); 69 70 // Reverse to get back the actual topological order. 71 // std::reverse does not work out of the box on SetVector and I want an 72 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 73 std::vector<Operation *> v(forwardSlice->takeVector()); 74 forwardSlice->insert(v.rbegin(), v.rend()); 75 } 76 77 static void getBackwardSliceImpl(Operation *op, 78 SetVector<Operation *> *backwardSlice, 79 TransitiveFilter filter) { 80 if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 81 return; 82 83 // Evaluate whether we should keep this def. 84 // This is useful in particular to implement scoping; i.e. return the 85 // transitive backwardSlice in the current scope. 86 if (filter && !filter(op)) 87 return; 88 89 for (auto en : llvm::enumerate(op->getOperands())) { 90 auto operand = en.value(); 91 if (auto *definingOp = operand.getDefiningOp()) { 92 if (backwardSlice->count(definingOp) == 0) 93 getBackwardSliceImpl(definingOp, backwardSlice, filter); 94 } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) { 95 Block *block = blockArg.getOwner(); 96 Operation *parentOp = block->getParentOp(); 97 // TODO: determine whether we want to recurse backward into the other 98 // blocks of parentOp, which are not technically backward unless they flow 99 // into us. For now, just bail. 100 assert(parentOp->getNumRegions() == 1 && 101 parentOp->getRegion(0).getBlocks().size() == 1); 102 if (backwardSlice->count(parentOp) == 0) 103 getBackwardSliceImpl(parentOp, backwardSlice, filter); 104 } else { 105 llvm_unreachable("No definingOp and not a block argument."); 106 } 107 } 108 109 backwardSlice->insert(op); 110 } 111 112 void mlir::getBackwardSlice(Operation *op, 113 SetVector<Operation *> *backwardSlice, 114 TransitiveFilter filter) { 115 getBackwardSliceImpl(op, backwardSlice, filter); 116 117 // Don't insert the top level operation, we just queried on it and don't 118 // want it in the results. 119 backwardSlice->remove(op); 120 } 121 122 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, 123 TransitiveFilter filter) { 124 if (Operation *definingOp = root.getDefiningOp()) { 125 getBackwardSlice(definingOp, backwardSlice, filter); 126 return; 127 } 128 Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp(); 129 getBackwardSlice(bbAargOwner, backwardSlice, filter); 130 } 131 132 SetVector<Operation *> mlir::getSlice(Operation *op, 133 TransitiveFilter backwardFilter, 134 TransitiveFilter forwardFilter) { 135 SetVector<Operation *> slice; 136 slice.insert(op); 137 138 unsigned currentIndex = 0; 139 SetVector<Operation *> backwardSlice; 140 SetVector<Operation *> forwardSlice; 141 while (currentIndex != slice.size()) { 142 auto *currentOp = (slice)[currentIndex]; 143 // Compute and insert the backwardSlice starting from currentOp. 144 backwardSlice.clear(); 145 getBackwardSlice(currentOp, &backwardSlice, backwardFilter); 146 slice.insert(backwardSlice.begin(), backwardSlice.end()); 147 148 // Compute and insert the forwardSlice starting from currentOp. 149 forwardSlice.clear(); 150 getForwardSlice(currentOp, &forwardSlice, forwardFilter); 151 slice.insert(forwardSlice.begin(), forwardSlice.end()); 152 ++currentIndex; 153 } 154 return topologicalSort(slice); 155 } 156 157 namespace { 158 /// DFS post-order implementation that maintains a global count to work across 159 /// multiple invocations, to help implement topological sort on multi-root DAGs. 160 /// We traverse all operations but only record the ones that appear in 161 /// `toSort` for the final result. 162 struct DFSState { 163 DFSState(const SetVector<Operation *> &set) 164 : toSort(set), topologicalCounts(), seen() {} 165 const SetVector<Operation *> &toSort; 166 SmallVector<Operation *, 16> topologicalCounts; 167 DenseSet<Operation *> seen; 168 }; 169 } // namespace 170 171 static void DFSPostorder(Operation *root, DFSState *state) { 172 SmallVector<Operation *> queue(1, root); 173 std::vector<Operation *> ops; 174 while (!queue.empty()) { 175 Operation *current = queue.pop_back_val(); 176 ops.push_back(current); 177 for (Value result : current->getResults()) { 178 for (Operation *op : result.getUsers()) 179 queue.push_back(op); 180 } 181 for (Region ®ion : current->getRegions()) { 182 for (Operation &op : region.getOps()) 183 queue.push_back(&op); 184 } 185 } 186 187 for (Operation *op : llvm::reverse(ops)) { 188 if (state->seen.insert(op).second && state->toSort.count(op) > 0) 189 state->topologicalCounts.push_back(op); 190 } 191 } 192 193 SetVector<Operation *> 194 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 195 if (toSort.empty()) { 196 return toSort; 197 } 198 199 // Run from each root with global count and `seen` set. 200 DFSState state(toSort); 201 for (auto *s : toSort) { 202 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 203 DFSPostorder(s, &state); 204 } 205 206 // Reorder and return. 207 SetVector<Operation *> res; 208 for (auto it = state.topologicalCounts.rbegin(), 209 eit = state.topologicalCounts.rend(); 210 it != eit; ++it) { 211 res.insert(*it); 212 } 213 return res; 214 } 215