1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/AffineOps/AffineOps.h" 15 #include "mlir/Dialect/LoopOps/LoopOps.h" 16 #include "mlir/IR/Function.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Support/Functional.h" 19 #include "mlir/Support/LLVM.h" 20 #include "mlir/Support/STLExtras.h" 21 #include "llvm/ADT/SetVector.h" 22 23 /// 24 /// Implements Analysis functions specific to slicing in Function. 25 /// 26 27 using namespace mlir; 28 29 using llvm::SetVector; 30 31 static void getForwardSliceImpl(Operation *op, 32 SetVector<Operation *> *forwardSlice, 33 TransitiveFilter filter) { 34 if (!op) { 35 return; 36 } 37 38 // Evaluate whether we should keep this use. 39 // This is useful in particular to implement scoping; i.e. return the 40 // transitive forwardSlice in the current scope. 41 if (!filter(op)) { 42 return; 43 } 44 45 if (auto forOp = dyn_cast<AffineForOp>(op)) { 46 for (auto *ownerInst : forOp.getInductionVar().getUsers()) 47 if (forwardSlice->count(ownerInst) == 0) 48 getForwardSliceImpl(ownerInst, forwardSlice, filter); 49 } else if (auto forOp = dyn_cast<loop::ForOp>(op)) { 50 for (auto *ownerInst : forOp.getInductionVar().getUsers()) 51 if (forwardSlice->count(ownerInst) == 0) 52 getForwardSliceImpl(ownerInst, forwardSlice, filter); 53 } else { 54 assert(op->getNumRegions() == 0 && "unexpected generic op with regions"); 55 assert(op->getNumResults() <= 1 && "unexpected multiple results"); 56 if (op->getNumResults() > 0) { 57 for (auto *ownerInst : op->getResult(0).getUsers()) 58 if (forwardSlice->count(ownerInst) == 0) 59 getForwardSliceImpl(ownerInst, forwardSlice, filter); 60 } 61 } 62 63 forwardSlice->insert(op); 64 } 65 66 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 67 TransitiveFilter filter) { 68 getForwardSliceImpl(op, forwardSlice, filter); 69 // Don't insert the top level operation, we just queried on it and don't 70 // want it in the results. 71 forwardSlice->remove(op); 72 73 // Reverse to get back the actual topological order. 74 // std::reverse does not work out of the box on SetVector and I want an 75 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 76 std::vector<Operation *> v(forwardSlice->takeVector()); 77 forwardSlice->insert(v.rbegin(), v.rend()); 78 } 79 80 static void getBackwardSliceImpl(Operation *op, 81 SetVector<Operation *> *backwardSlice, 82 TransitiveFilter filter) { 83 if (!op) 84 return; 85 86 assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) || 87 isa<loop::ForOp>(op)) && 88 "unexpected generic op with regions"); 89 90 // Evaluate whether we should keep this def. 91 // This is useful in particular to implement scoping; i.e. return the 92 // transitive forwardSlice in the current scope. 93 if (!filter(op)) { 94 return; 95 } 96 97 for (auto en : llvm::enumerate(op->getOperands())) { 98 auto operand = en.value(); 99 if (auto blockArg = operand.dyn_cast<BlockArgument>()) { 100 if (auto affIv = getForInductionVarOwner(operand)) { 101 auto *affOp = affIv.getOperation(); 102 if (backwardSlice->count(affOp) == 0) 103 getBackwardSliceImpl(affOp, backwardSlice, filter); 104 } else if (auto loopIv = loop::getForInductionVarOwner(operand)) { 105 auto *loopOp = loopIv.getOperation(); 106 if (backwardSlice->count(loopOp) == 0) 107 getBackwardSliceImpl(loopOp, backwardSlice, filter); 108 } else if (blockArg.getOwner() != 109 &op->getParentOfType<FuncOp>().getBody().front()) { 110 op->emitError("unsupported CF for operand ") << en.index(); 111 llvm_unreachable("Unsupported control flow"); 112 } 113 continue; 114 } 115 auto *op = operand.getDefiningOp(); 116 if (backwardSlice->count(op) == 0) { 117 getBackwardSliceImpl(op, backwardSlice, filter); 118 } 119 } 120 121 backwardSlice->insert(op); 122 } 123 124 void mlir::getBackwardSlice(Operation *op, 125 SetVector<Operation *> *backwardSlice, 126 TransitiveFilter filter) { 127 getBackwardSliceImpl(op, backwardSlice, filter); 128 129 // Don't insert the top level operation, we just queried on it and don't 130 // want it in the results. 131 backwardSlice->remove(op); 132 } 133 134 SetVector<Operation *> mlir::getSlice(Operation *op, 135 TransitiveFilter backwardFilter, 136 TransitiveFilter forwardFilter) { 137 SetVector<Operation *> slice; 138 slice.insert(op); 139 140 unsigned currentIndex = 0; 141 SetVector<Operation *> backwardSlice; 142 SetVector<Operation *> forwardSlice; 143 while (currentIndex != slice.size()) { 144 auto *currentInst = (slice)[currentIndex]; 145 // Compute and insert the backwardSlice starting from currentInst. 146 backwardSlice.clear(); 147 getBackwardSlice(currentInst, &backwardSlice, backwardFilter); 148 slice.insert(backwardSlice.begin(), backwardSlice.end()); 149 150 // Compute and insert the forwardSlice starting from currentInst. 151 forwardSlice.clear(); 152 getForwardSlice(currentInst, &forwardSlice, forwardFilter); 153 slice.insert(forwardSlice.begin(), forwardSlice.end()); 154 ++currentIndex; 155 } 156 return topologicalSort(slice); 157 } 158 159 namespace { 160 /// DFS post-order implementation that maintains a global count to work across 161 /// multiple invocations, to help implement topological sort on multi-root DAGs. 162 /// We traverse all operations but only record the ones that appear in 163 /// `toSort` for the final result. 164 struct DFSState { 165 DFSState(const SetVector<Operation *> &set) 166 : toSort(set), topologicalCounts(), seen() {} 167 const SetVector<Operation *> &toSort; 168 SmallVector<Operation *, 16> topologicalCounts; 169 DenseSet<Operation *> seen; 170 }; 171 } // namespace 172 173 static void DFSPostorder(Operation *current, DFSState *state) { 174 assert(current->getNumResults() <= 1 && "NYI: multi-result"); 175 if (current->getNumResults() > 0) { 176 for (auto &u : current->getResult(0).getUses()) { 177 auto *op = u.getOwner(); 178 DFSPostorder(op, state); 179 } 180 } 181 bool inserted; 182 using IterTy = decltype(state->seen.begin()); 183 IterTy iter; 184 std::tie(iter, inserted) = state->seen.insert(current); 185 if (inserted) { 186 if (state->toSort.count(current) > 0) { 187 state->topologicalCounts.push_back(current); 188 } 189 } 190 } 191 192 SetVector<Operation *> 193 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 194 if (toSort.empty()) { 195 return toSort; 196 } 197 198 // Run from each root with global count and `seen` set. 199 DFSState state(toSort); 200 for (auto *s : toSort) { 201 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 202 DFSPostorder(s, &state); 203 } 204 205 // Reorder and return. 206 SetVector<Operation *> res; 207 for (auto it = state.topologicalCounts.rbegin(), 208 eit = state.topologicalCounts.rend(); 209 it != eit; ++it) { 210 res.insert(*it); 211 } 212 return res; 213 } 214