1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/Support/LLVM.h" 20 #include "llvm/ADT/SetVector.h" 21 22 /// 23 /// Implements Analysis functions specific to slicing in Function. 24 /// 25 26 using namespace mlir; 27 28 using llvm::SetVector; 29 30 static void getForwardSliceImpl(Operation *op, 31 SetVector<Operation *> *forwardSlice, 32 TransitiveFilter filter) { 33 if (!op) 34 return; 35 36 // Evaluate whether we should keep this use. 37 // This is useful in particular to implement scoping; i.e. return the 38 // transitive forwardSlice in the current scope. 39 if (filter && !filter(op)) 40 return; 41 42 for (Region ®ion : op->getRegions()) 43 for (Block &block : region) 44 for (Operation &blockOp : block) 45 if (forwardSlice->count(&blockOp) == 0) 46 getForwardSliceImpl(&blockOp, forwardSlice, filter); 47 for (Value result : op->getResults()) { 48 for (Operation *userOp : result.getUsers()) 49 if (forwardSlice->count(userOp) == 0) 50 getForwardSliceImpl(userOp, forwardSlice, filter); 51 } 52 53 forwardSlice->insert(op); 54 } 55 56 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 57 TransitiveFilter filter) { 58 getForwardSliceImpl(op, forwardSlice, filter); 59 // Don't insert the top level operation, we just queried on it and don't 60 // want it in the results. 61 forwardSlice->remove(op); 62 63 // Reverse to get back the actual topological order. 64 // std::reverse does not work out of the box on SetVector and I want an 65 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 66 std::vector<Operation *> v(forwardSlice->takeVector()); 67 forwardSlice->insert(v.rbegin(), v.rend()); 68 } 69 70 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, 71 TransitiveFilter filter) { 72 for (Operation *user : root.getUsers()) 73 getForwardSliceImpl(user, forwardSlice, filter); 74 75 // Reverse to get back the actual topological order. 76 // std::reverse does not work out of the box on SetVector and I want an 77 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 78 std::vector<Operation *> v(forwardSlice->takeVector()); 79 forwardSlice->insert(v.rbegin(), v.rend()); 80 } 81 82 static void getBackwardSliceImpl(Operation *op, 83 SetVector<Operation *> *backwardSlice, 84 TransitiveFilter filter) { 85 if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 86 return; 87 88 // Evaluate whether we should keep this def. 89 // This is useful in particular to implement scoping; i.e. return the 90 // transitive backwardSlice in the current scope. 91 if (filter && !filter(op)) 92 return; 93 94 for (auto en : llvm::enumerate(op->getOperands())) { 95 auto operand = en.value(); 96 if (auto *definingOp = operand.getDefiningOp()) { 97 if (backwardSlice->count(definingOp) == 0) 98 getBackwardSliceImpl(definingOp, backwardSlice, filter); 99 } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) { 100 Block *block = blockArg.getOwner(); 101 Operation *parentOp = block->getParentOp(); 102 // TODO: determine whether we want to recurse backward into the other 103 // blocks of parentOp, which are not technically backward unless they flow 104 // into us. For now, just bail. 105 assert(parentOp->getNumRegions() == 1 && 106 parentOp->getRegion(0).getBlocks().size() == 1); 107 if (backwardSlice->count(parentOp) == 0) 108 getBackwardSliceImpl(parentOp, backwardSlice, filter); 109 } else { 110 llvm_unreachable("No definingOp and not a block argument."); 111 } 112 } 113 114 backwardSlice->insert(op); 115 } 116 117 void mlir::getBackwardSlice(Operation *op, 118 SetVector<Operation *> *backwardSlice, 119 TransitiveFilter filter) { 120 getBackwardSliceImpl(op, backwardSlice, filter); 121 122 // Don't insert the top level operation, we just queried on it and don't 123 // want it in the results. 124 backwardSlice->remove(op); 125 } 126 127 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, 128 TransitiveFilter filter) { 129 if (Operation *definingOp = root.getDefiningOp()) { 130 getBackwardSlice(definingOp, backwardSlice, filter); 131 return; 132 } 133 Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp(); 134 getBackwardSlice(bbAargOwner, backwardSlice, filter); 135 } 136 137 SetVector<Operation *> mlir::getSlice(Operation *op, 138 TransitiveFilter backwardFilter, 139 TransitiveFilter forwardFilter) { 140 SetVector<Operation *> slice; 141 slice.insert(op); 142 143 unsigned currentIndex = 0; 144 SetVector<Operation *> backwardSlice; 145 SetVector<Operation *> forwardSlice; 146 while (currentIndex != slice.size()) { 147 auto *currentOp = (slice)[currentIndex]; 148 // Compute and insert the backwardSlice starting from currentOp. 149 backwardSlice.clear(); 150 getBackwardSlice(currentOp, &backwardSlice, backwardFilter); 151 slice.insert(backwardSlice.begin(), backwardSlice.end()); 152 153 // Compute and insert the forwardSlice starting from currentOp. 154 forwardSlice.clear(); 155 getForwardSlice(currentOp, &forwardSlice, forwardFilter); 156 slice.insert(forwardSlice.begin(), forwardSlice.end()); 157 ++currentIndex; 158 } 159 return topologicalSort(slice); 160 } 161 162 namespace { 163 /// DFS post-order implementation that maintains a global count to work across 164 /// multiple invocations, to help implement topological sort on multi-root DAGs. 165 /// We traverse all operations but only record the ones that appear in 166 /// `toSort` for the final result. 167 struct DFSState { 168 DFSState(const SetVector<Operation *> &set) 169 : toSort(set), topologicalCounts(), seen() {} 170 const SetVector<Operation *> &toSort; 171 SmallVector<Operation *, 16> topologicalCounts; 172 DenseSet<Operation *> seen; 173 }; 174 } // namespace 175 176 static void DFSPostorder(Operation *current, DFSState *state) { 177 for (Value result : current->getResults()) { 178 for (Operation *op : result.getUsers()) 179 DFSPostorder(op, state); 180 } 181 bool inserted; 182 using IterTy = decltype(state->seen.begin()); 183 IterTy iter; 184 std::tie(iter, inserted) = state->seen.insert(current); 185 if (inserted) { 186 if (state->toSort.count(current) > 0) { 187 state->topologicalCounts.push_back(current); 188 } 189 } 190 } 191 192 SetVector<Operation *> 193 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 194 if (toSort.empty()) { 195 return toSort; 196 } 197 198 // Run from each root with global count and `seen` set. 199 DFSState state(toSort); 200 for (auto *s : toSort) { 201 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 202 DFSPostorder(s, &state); 203 } 204 205 // Reorder and return. 206 SetVector<Operation *> res; 207 for (auto it = state.topologicalCounts.rbegin(), 208 eit = state.topologicalCounts.rend(); 209 it != eit; ++it) { 210 res.insert(*it); 211 } 212 return res; 213 } 214