1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements Analysis functions specific to slicing in Function. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Analysis/SliceAnalysis.h" 23 #include "mlir/Dialect/AffineOps/AffineOps.h" 24 #include "mlir/Dialect/LoopOps/LoopOps.h" 25 #include "mlir/IR/Function.h" 26 #include "mlir/IR/Operation.h" 27 #include "mlir/Support/Functional.h" 28 #include "mlir/Support/LLVM.h" 29 #include "mlir/Support/STLExtras.h" 30 #include "llvm/ADT/SetVector.h" 31 32 /// 33 /// Implements Analysis functions specific to slicing in Function. 34 /// 35 36 using namespace mlir; 37 38 using llvm::SetVector; 39 40 static void getForwardSliceImpl(Operation *op, 41 SetVector<Operation *> *forwardSlice, 42 TransitiveFilter filter) { 43 if (!op) { 44 return; 45 } 46 47 // Evaluate whether we should keep this use. 48 // This is useful in particular to implement scoping; i.e. return the 49 // transitive forwardSlice in the current scope. 50 if (!filter(op)) { 51 return; 52 } 53 54 if (auto forOp = dyn_cast<AffineForOp>(op)) { 55 for (auto *ownerInst : forOp.getInductionVar()->getUsers()) 56 if (forwardSlice->count(ownerInst) == 0) 57 getForwardSliceImpl(ownerInst, forwardSlice, filter); 58 } else if (auto forOp = dyn_cast<loop::ForOp>(op)) { 59 for (auto *ownerInst : forOp.getInductionVar()->getUsers()) 60 if (forwardSlice->count(ownerInst) == 0) 61 getForwardSliceImpl(ownerInst, forwardSlice, filter); 62 } else { 63 assert(op->getNumRegions() == 0 && "unexpected generic op with regions"); 64 assert(op->getNumResults() <= 1 && "unexpected multiple results"); 65 if (op->getNumResults() > 0) { 66 for (auto *ownerInst : op->getResult(0)->getUsers()) 67 if (forwardSlice->count(ownerInst) == 0) 68 getForwardSliceImpl(ownerInst, forwardSlice, filter); 69 } 70 } 71 72 forwardSlice->insert(op); 73 } 74 75 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 76 TransitiveFilter filter) { 77 getForwardSliceImpl(op, forwardSlice, filter); 78 // Don't insert the top level operation, we just queried on it and don't 79 // want it in the results. 80 forwardSlice->remove(op); 81 82 // Reverse to get back the actual topological order. 83 // std::reverse does not work out of the box on SetVector and I want an 84 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 85 std::vector<Operation *> v(forwardSlice->takeVector()); 86 forwardSlice->insert(v.rbegin(), v.rend()); 87 } 88 89 static void getBackwardSliceImpl(Operation *op, 90 SetVector<Operation *> *backwardSlice, 91 TransitiveFilter filter) { 92 if (!op) 93 return; 94 95 assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) || 96 isa<loop::ForOp>(op)) && 97 "unexpected generic op with regions"); 98 99 // Evaluate whether we should keep this def. 100 // This is useful in particular to implement scoping; i.e. return the 101 // transitive forwardSlice in the current scope. 102 if (!filter(op)) { 103 return; 104 } 105 106 for (auto en : llvm::enumerate(op->getOperands())) { 107 auto *operand = en.value(); 108 if (auto *blockArg = dyn_cast<BlockArgument>(operand)) { 109 if (auto affIv = getForInductionVarOwner(operand)) { 110 auto *affOp = affIv.getOperation(); 111 if (backwardSlice->count(affOp) == 0) 112 getBackwardSliceImpl(affOp, backwardSlice, filter); 113 } else if (auto loopIv = loop::getForInductionVarOwner(operand)) { 114 auto *loopOp = loopIv.getOperation(); 115 if (backwardSlice->count(loopOp) == 0) 116 getBackwardSliceImpl(loopOp, backwardSlice, filter); 117 } else if (blockArg->getOwner() != 118 &op->getParentOfType<FuncOp>().getBody().front()) { 119 op->emitError("unsupported CF for operand ") << en.index(); 120 llvm_unreachable("Unsupported control flow"); 121 } 122 continue; 123 } 124 auto *op = operand->getDefiningOp(); 125 if (backwardSlice->count(op) == 0) { 126 getBackwardSliceImpl(op, backwardSlice, filter); 127 } 128 } 129 130 backwardSlice->insert(op); 131 } 132 133 void mlir::getBackwardSlice(Operation *op, 134 SetVector<Operation *> *backwardSlice, 135 TransitiveFilter filter) { 136 getBackwardSliceImpl(op, backwardSlice, filter); 137 138 // Don't insert the top level operation, we just queried on it and don't 139 // want it in the results. 140 backwardSlice->remove(op); 141 } 142 143 SetVector<Operation *> mlir::getSlice(Operation *op, 144 TransitiveFilter backwardFilter, 145 TransitiveFilter forwardFilter) { 146 SetVector<Operation *> slice; 147 slice.insert(op); 148 149 unsigned currentIndex = 0; 150 SetVector<Operation *> backwardSlice; 151 SetVector<Operation *> forwardSlice; 152 while (currentIndex != slice.size()) { 153 auto *currentInst = (slice)[currentIndex]; 154 // Compute and insert the backwardSlice starting from currentInst. 155 backwardSlice.clear(); 156 getBackwardSlice(currentInst, &backwardSlice, backwardFilter); 157 slice.insert(backwardSlice.begin(), backwardSlice.end()); 158 159 // Compute and insert the forwardSlice starting from currentInst. 160 forwardSlice.clear(); 161 getForwardSlice(currentInst, &forwardSlice, forwardFilter); 162 slice.insert(forwardSlice.begin(), forwardSlice.end()); 163 ++currentIndex; 164 } 165 return topologicalSort(slice); 166 } 167 168 namespace { 169 /// DFS post-order implementation that maintains a global count to work across 170 /// multiple invocations, to help implement topological sort on multi-root DAGs. 171 /// We traverse all operations but only record the ones that appear in 172 /// `toSort` for the final result. 173 struct DFSState { 174 DFSState(const SetVector<Operation *> &set) 175 : toSort(set), topologicalCounts(), seen() {} 176 const SetVector<Operation *> &toSort; 177 SmallVector<Operation *, 16> topologicalCounts; 178 DenseSet<Operation *> seen; 179 }; 180 } // namespace 181 182 static void DFSPostorder(Operation *current, DFSState *state) { 183 assert(current->getNumResults() <= 1 && "NYI: multi-result"); 184 if (current->getNumResults() > 0) { 185 for (auto &u : current->getResult(0)->getUses()) { 186 auto *op = u.getOwner(); 187 DFSPostorder(op, state); 188 } 189 } 190 bool inserted; 191 using IterTy = decltype(state->seen.begin()); 192 IterTy iter; 193 std::tie(iter, inserted) = state->seen.insert(current); 194 if (inserted) { 195 if (state->toSort.count(current) > 0) { 196 state->topologicalCounts.push_back(current); 197 } 198 } 199 } 200 201 SetVector<Operation *> 202 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 203 if (toSort.empty()) { 204 return toSort; 205 } 206 207 // Run from each root with global count and `seen` set. 208 DFSState state(toSort); 209 for (auto *s : toSort) { 210 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 211 DFSPostorder(s, &state); 212 } 213 214 // Reorder and return. 215 SetVector<Operation *> res; 216 for (auto it = state.topologicalCounts.rbegin(), 217 eit = state.topologicalCounts.rend(); 218 it != eit; ++it) { 219 res.insert(*it); 220 } 221 return res; 222 } 223