1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/LoopOps/LoopOps.h" 16 #include "mlir/IR/Function.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Support/LLVM.h" 19 #include "mlir/Support/STLExtras.h" 20 #include "llvm/ADT/SetVector.h" 21 22 /// 23 /// Implements Analysis functions specific to slicing in Function. 24 /// 25 26 using namespace mlir; 27 28 using llvm::SetVector; 29 30 static void getForwardSliceImpl(Operation *op, 31 SetVector<Operation *> *forwardSlice, 32 TransitiveFilter filter) { 33 if (!op) { 34 return; 35 } 36 37 // Evaluate whether we should keep this use. 38 // This is useful in particular to implement scoping; i.e. return the 39 // transitive forwardSlice in the current scope. 40 if (!filter(op)) { 41 return; 42 } 43 44 if (auto forOp = dyn_cast<AffineForOp>(op)) { 45 for (auto *ownerInst : forOp.getInductionVar().getUsers()) 46 if (forwardSlice->count(ownerInst) == 0) 47 getForwardSliceImpl(ownerInst, forwardSlice, filter); 48 } else if (auto forOp = dyn_cast<loop::ForOp>(op)) { 49 for (auto *ownerInst : forOp.getInductionVar().getUsers()) 50 if (forwardSlice->count(ownerInst) == 0) 51 getForwardSliceImpl(ownerInst, forwardSlice, filter); 52 } else { 53 assert(op->getNumRegions() == 0 && "unexpected generic op with regions"); 54 assert(op->getNumResults() <= 1 && "unexpected multiple results"); 55 if (op->getNumResults() > 0) { 56 for (auto *ownerInst : op->getResult(0).getUsers()) 57 if (forwardSlice->count(ownerInst) == 0) 58 getForwardSliceImpl(ownerInst, forwardSlice, filter); 59 } 60 } 61 62 forwardSlice->insert(op); 63 } 64 65 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 66 TransitiveFilter filter) { 67 getForwardSliceImpl(op, forwardSlice, filter); 68 // Don't insert the top level operation, we just queried on it and don't 69 // want it in the results. 70 forwardSlice->remove(op); 71 72 // Reverse to get back the actual topological order. 73 // std::reverse does not work out of the box on SetVector and I want an 74 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 75 std::vector<Operation *> v(forwardSlice->takeVector()); 76 forwardSlice->insert(v.rbegin(), v.rend()); 77 } 78 79 static void getBackwardSliceImpl(Operation *op, 80 SetVector<Operation *> *backwardSlice, 81 TransitiveFilter filter) { 82 if (!op) 83 return; 84 85 assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) || 86 isa<loop::ForOp>(op)) && 87 "unexpected generic op with regions"); 88 89 // Evaluate whether we should keep this def. 90 // This is useful in particular to implement scoping; i.e. return the 91 // transitive forwardSlice in the current scope. 92 if (!filter(op)) { 93 return; 94 } 95 96 for (auto en : llvm::enumerate(op->getOperands())) { 97 auto operand = en.value(); 98 if (auto blockArg = operand.dyn_cast<BlockArgument>()) { 99 if (auto affIv = getForInductionVarOwner(operand)) { 100 auto *affOp = affIv.getOperation(); 101 if (backwardSlice->count(affOp) == 0) 102 getBackwardSliceImpl(affOp, backwardSlice, filter); 103 } else if (auto loopIv = loop::getForInductionVarOwner(operand)) { 104 auto *loopOp = loopIv.getOperation(); 105 if (backwardSlice->count(loopOp) == 0) 106 getBackwardSliceImpl(loopOp, backwardSlice, filter); 107 } else if (blockArg.getOwner() != 108 &op->getParentOfType<FuncOp>().getBody().front()) { 109 op->emitError("unsupported CF for operand ") << en.index(); 110 llvm_unreachable("Unsupported control flow"); 111 } 112 continue; 113 } 114 auto *op = operand.getDefiningOp(); 115 if (backwardSlice->count(op) == 0) { 116 getBackwardSliceImpl(op, backwardSlice, filter); 117 } 118 } 119 120 backwardSlice->insert(op); 121 } 122 123 void mlir::getBackwardSlice(Operation *op, 124 SetVector<Operation *> *backwardSlice, 125 TransitiveFilter filter) { 126 getBackwardSliceImpl(op, backwardSlice, filter); 127 128 // Don't insert the top level operation, we just queried on it and don't 129 // want it in the results. 130 backwardSlice->remove(op); 131 } 132 133 SetVector<Operation *> mlir::getSlice(Operation *op, 134 TransitiveFilter backwardFilter, 135 TransitiveFilter forwardFilter) { 136 SetVector<Operation *> slice; 137 slice.insert(op); 138 139 unsigned currentIndex = 0; 140 SetVector<Operation *> backwardSlice; 141 SetVector<Operation *> forwardSlice; 142 while (currentIndex != slice.size()) { 143 auto *currentInst = (slice)[currentIndex]; 144 // Compute and insert the backwardSlice starting from currentInst. 145 backwardSlice.clear(); 146 getBackwardSlice(currentInst, &backwardSlice, backwardFilter); 147 slice.insert(backwardSlice.begin(), backwardSlice.end()); 148 149 // Compute and insert the forwardSlice starting from currentInst. 150 forwardSlice.clear(); 151 getForwardSlice(currentInst, &forwardSlice, forwardFilter); 152 slice.insert(forwardSlice.begin(), forwardSlice.end()); 153 ++currentIndex; 154 } 155 return topologicalSort(slice); 156 } 157 158 namespace { 159 /// DFS post-order implementation that maintains a global count to work across 160 /// multiple invocations, to help implement topological sort on multi-root DAGs. 161 /// We traverse all operations but only record the ones that appear in 162 /// `toSort` for the final result. 163 struct DFSState { 164 DFSState(const SetVector<Operation *> &set) 165 : toSort(set), topologicalCounts(), seen() {} 166 const SetVector<Operation *> &toSort; 167 SmallVector<Operation *, 16> topologicalCounts; 168 DenseSet<Operation *> seen; 169 }; 170 } // namespace 171 172 static void DFSPostorder(Operation *current, DFSState *state) { 173 assert(current->getNumResults() <= 1 && "NYI: multi-result"); 174 if (current->getNumResults() > 0) { 175 for (auto &u : current->getResult(0).getUses()) { 176 auto *op = u.getOwner(); 177 DFSPostorder(op, state); 178 } 179 } 180 bool inserted; 181 using IterTy = decltype(state->seen.begin()); 182 IterTy iter; 183 std::tie(iter, inserted) = state->seen.insert(current); 184 if (inserted) { 185 if (state->toSort.count(current) > 0) { 186 state->topologicalCounts.push_back(current); 187 } 188 } 189 } 190 191 SetVector<Operation *> 192 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 193 if (toSort.empty()) { 194 return toSort; 195 } 196 197 // Run from each root with global count and `seen` set. 198 DFSState state(toSort); 199 for (auto *s : toSort) { 200 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 201 DFSPostorder(s, &state); 202 } 203 204 // Reorder and return. 205 SetVector<Operation *> res; 206 for (auto it = state.topologicalCounts.rbegin(), 207 eit = state.topologicalCounts.rend(); 208 it != eit; ++it) { 209 res.insert(*it); 210 } 211 return res; 212 } 213