1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/Support/LLVM.h" 20 #include "llvm/ADT/SetVector.h" 21 22 /// 23 /// Implements Analysis functions specific to slicing in Function. 24 /// 25 26 using namespace mlir; 27 28 static void getForwardSliceImpl(Operation *op, 29 SetVector<Operation *> *forwardSlice, 30 TransitiveFilter filter) { 31 if (!op) 32 return; 33 34 // Evaluate whether we should keep this use. 35 // This is useful in particular to implement scoping; i.e. return the 36 // transitive forwardSlice in the current scope. 37 if (filter && !filter(op)) 38 return; 39 40 for (Region ®ion : op->getRegions()) 41 for (Block &block : region) 42 for (Operation &blockOp : block) 43 if (forwardSlice->count(&blockOp) == 0) 44 getForwardSliceImpl(&blockOp, forwardSlice, filter); 45 for (Value result : op->getResults()) { 46 for (Operation *userOp : result.getUsers()) 47 if (forwardSlice->count(userOp) == 0) 48 getForwardSliceImpl(userOp, forwardSlice, filter); 49 } 50 51 forwardSlice->insert(op); 52 } 53 54 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 55 TransitiveFilter filter) { 56 getForwardSliceImpl(op, forwardSlice, filter); 57 // Don't insert the top level operation, we just queried on it and don't 58 // want it in the results. 59 forwardSlice->remove(op); 60 61 // Reverse to get back the actual topological order. 62 // std::reverse does not work out of the box on SetVector and I want an 63 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 64 std::vector<Operation *> v(forwardSlice->takeVector()); 65 forwardSlice->insert(v.rbegin(), v.rend()); 66 } 67 68 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, 69 TransitiveFilter filter) { 70 for (Operation *user : root.getUsers()) 71 getForwardSliceImpl(user, forwardSlice, filter); 72 73 // Reverse to get back the actual topological order. 74 // std::reverse does not work out of the box on SetVector and I want an 75 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 76 std::vector<Operation *> v(forwardSlice->takeVector()); 77 forwardSlice->insert(v.rbegin(), v.rend()); 78 } 79 80 static void getBackwardSliceImpl(Operation *op, 81 SetVector<Operation *> *backwardSlice, 82 TransitiveFilter filter) { 83 if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 84 return; 85 86 // Evaluate whether we should keep this def. 87 // This is useful in particular to implement scoping; i.e. return the 88 // transitive backwardSlice in the current scope. 89 if (filter && !filter(op)) 90 return; 91 92 for (auto en : llvm::enumerate(op->getOperands())) { 93 auto operand = en.value(); 94 if (auto *definingOp = operand.getDefiningOp()) { 95 if (backwardSlice->count(definingOp) == 0) 96 getBackwardSliceImpl(definingOp, backwardSlice, filter); 97 } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) { 98 Block *block = blockArg.getOwner(); 99 Operation *parentOp = block->getParentOp(); 100 // TODO: determine whether we want to recurse backward into the other 101 // blocks of parentOp, which are not technically backward unless they flow 102 // into us. For now, just bail. 103 assert(parentOp->getNumRegions() == 1 && 104 parentOp->getRegion(0).getBlocks().size() == 1); 105 if (backwardSlice->count(parentOp) == 0) 106 getBackwardSliceImpl(parentOp, backwardSlice, filter); 107 } else { 108 llvm_unreachable("No definingOp and not a block argument."); 109 } 110 } 111 112 backwardSlice->insert(op); 113 } 114 115 void mlir::getBackwardSlice(Operation *op, 116 SetVector<Operation *> *backwardSlice, 117 TransitiveFilter filter) { 118 getBackwardSliceImpl(op, backwardSlice, filter); 119 120 // Don't insert the top level operation, we just queried on it and don't 121 // want it in the results. 122 backwardSlice->remove(op); 123 } 124 125 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, 126 TransitiveFilter filter) { 127 if (Operation *definingOp = root.getDefiningOp()) { 128 getBackwardSlice(definingOp, backwardSlice, filter); 129 return; 130 } 131 Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp(); 132 getBackwardSlice(bbAargOwner, backwardSlice, filter); 133 } 134 135 SetVector<Operation *> mlir::getSlice(Operation *op, 136 TransitiveFilter backwardFilter, 137 TransitiveFilter forwardFilter) { 138 SetVector<Operation *> slice; 139 slice.insert(op); 140 141 unsigned currentIndex = 0; 142 SetVector<Operation *> backwardSlice; 143 SetVector<Operation *> forwardSlice; 144 while (currentIndex != slice.size()) { 145 auto *currentOp = (slice)[currentIndex]; 146 // Compute and insert the backwardSlice starting from currentOp. 147 backwardSlice.clear(); 148 getBackwardSlice(currentOp, &backwardSlice, backwardFilter); 149 slice.insert(backwardSlice.begin(), backwardSlice.end()); 150 151 // Compute and insert the forwardSlice starting from currentOp. 152 forwardSlice.clear(); 153 getForwardSlice(currentOp, &forwardSlice, forwardFilter); 154 slice.insert(forwardSlice.begin(), forwardSlice.end()); 155 ++currentIndex; 156 } 157 return topologicalSort(slice); 158 } 159 160 namespace { 161 /// DFS post-order implementation that maintains a global count to work across 162 /// multiple invocations, to help implement topological sort on multi-root DAGs. 163 /// We traverse all operations but only record the ones that appear in 164 /// `toSort` for the final result. 165 struct DFSState { 166 DFSState(const SetVector<Operation *> &set) 167 : toSort(set), topologicalCounts(), seen() {} 168 const SetVector<Operation *> &toSort; 169 SmallVector<Operation *, 16> topologicalCounts; 170 DenseSet<Operation *> seen; 171 }; 172 } // namespace 173 174 static void DFSPostorder(Operation *current, DFSState *state) { 175 for (Value result : current->getResults()) { 176 for (Operation *op : result.getUsers()) 177 DFSPostorder(op, state); 178 } 179 bool inserted; 180 using IterTy = decltype(state->seen.begin()); 181 IterTy iter; 182 std::tie(iter, inserted) = state->seen.insert(current); 183 if (inserted) { 184 if (state->toSort.count(current) > 0) { 185 state->topologicalCounts.push_back(current); 186 } 187 } 188 } 189 190 SetVector<Operation *> 191 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 192 if (toSort.empty()) { 193 return toSort; 194 } 195 196 // Run from each root with global count and `seen` set. 197 DFSState state(toSort); 198 for (auto *s : toSort) { 199 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 200 DFSPostorder(s, &state); 201 } 202 203 // Reorder and return. 204 SetVector<Operation *> res; 205 for (auto it = state.topologicalCounts.rbegin(), 206 eit = state.topologicalCounts.rend(); 207 it != eit; ++it) { 208 res.insert(*it); 209 } 210 return res; 211 } 212