1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/IR/Function.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Support/LLVM.h" 19 #include "llvm/ADT/SetVector.h" 20 21 /// 22 /// Implements Analysis functions specific to slicing in Function. 23 /// 24 25 using namespace mlir; 26 27 using llvm::SetVector; 28 29 static void getForwardSliceImpl(Operation *op, 30 SetVector<Operation *> *forwardSlice, 31 TransitiveFilter filter) { 32 if (!op) { 33 return; 34 } 35 36 // Evaluate whether we should keep this use. 37 // This is useful in particular to implement scoping; i.e. return the 38 // transitive forwardSlice in the current scope. 39 if (!filter(op)) { 40 return; 41 } 42 43 if (auto forOp = dyn_cast<AffineForOp>(op)) { 44 for (auto *ownerOp : forOp.getInductionVar().getUsers()) 45 if (forwardSlice->count(ownerOp) == 0) 46 getForwardSliceImpl(ownerOp, forwardSlice, filter); 47 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) { 48 for (auto *ownerOp : forOp.getInductionVar().getUsers()) 49 if (forwardSlice->count(ownerOp) == 0) 50 getForwardSliceImpl(ownerOp, forwardSlice, filter); 51 for (auto result : forOp.getResults()) 52 for (auto *ownerOp : result.getUsers()) 53 if (forwardSlice->count(ownerOp) == 0) 54 getForwardSliceImpl(ownerOp, forwardSlice, filter); 55 } else { 56 assert(op->getNumRegions() == 0 && "unexpected generic op with regions"); 57 assert(op->getNumResults() <= 1 && "unexpected multiple results"); 58 if (op->getNumResults() > 0) { 59 for (auto *ownerOp : op->getResult(0).getUsers()) 60 if (forwardSlice->count(ownerOp) == 0) 61 getForwardSliceImpl(ownerOp, forwardSlice, filter); 62 } 63 } 64 65 forwardSlice->insert(op); 66 } 67 68 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 69 TransitiveFilter filter) { 70 getForwardSliceImpl(op, forwardSlice, filter); 71 // Don't insert the top level operation, we just queried on it and don't 72 // want it in the results. 73 forwardSlice->remove(op); 74 75 // Reverse to get back the actual topological order. 76 // std::reverse does not work out of the box on SetVector and I want an 77 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 78 std::vector<Operation *> v(forwardSlice->takeVector()); 79 forwardSlice->insert(v.rbegin(), v.rend()); 80 } 81 82 static void getBackwardSliceImpl(Operation *op, 83 SetVector<Operation *> *backwardSlice, 84 TransitiveFilter filter) { 85 if (!op) 86 return; 87 88 assert((op->getNumRegions() == 0 || isa<AffineForOp, scf::ForOp>(op)) && 89 "unexpected generic op with regions"); 90 91 // Evaluate whether we should keep this def. 92 // This is useful in particular to implement scoping; i.e. return the 93 // transitive forwardSlice in the current scope. 94 if (!filter(op)) { 95 return; 96 } 97 98 for (auto en : llvm::enumerate(op->getOperands())) { 99 auto operand = en.value(); 100 if (auto blockArg = operand.dyn_cast<BlockArgument>()) { 101 if (auto affIv = getForInductionVarOwner(operand)) { 102 auto *affOp = affIv.getOperation(); 103 if (backwardSlice->count(affOp) == 0) 104 getBackwardSliceImpl(affOp, backwardSlice, filter); 105 } else if (auto loopIv = scf::getForInductionVarOwner(operand)) { 106 auto *loopOp = loopIv.getOperation(); 107 if (backwardSlice->count(loopOp) == 0) 108 getBackwardSliceImpl(loopOp, backwardSlice, filter); 109 } else if (blockArg.getOwner() != 110 &op->getParentOfType<FuncOp>().getBody().front()) { 111 op->emitError("unsupported CF for operand ") << en.index(); 112 llvm_unreachable("Unsupported control flow"); 113 } 114 continue; 115 } 116 auto *op = operand.getDefiningOp(); 117 if (backwardSlice->count(op) == 0) { 118 getBackwardSliceImpl(op, backwardSlice, filter); 119 } 120 } 121 122 backwardSlice->insert(op); 123 } 124 125 void mlir::getBackwardSlice(Operation *op, 126 SetVector<Operation *> *backwardSlice, 127 TransitiveFilter filter) { 128 getBackwardSliceImpl(op, backwardSlice, filter); 129 130 // Don't insert the top level operation, we just queried on it and don't 131 // want it in the results. 132 backwardSlice->remove(op); 133 } 134 135 SetVector<Operation *> mlir::getSlice(Operation *op, 136 TransitiveFilter backwardFilter, 137 TransitiveFilter forwardFilter) { 138 SetVector<Operation *> slice; 139 slice.insert(op); 140 141 unsigned currentIndex = 0; 142 SetVector<Operation *> backwardSlice; 143 SetVector<Operation *> forwardSlice; 144 while (currentIndex != slice.size()) { 145 auto *currentOp = (slice)[currentIndex]; 146 // Compute and insert the backwardSlice starting from currentOp. 147 backwardSlice.clear(); 148 getBackwardSlice(currentOp, &backwardSlice, backwardFilter); 149 slice.insert(backwardSlice.begin(), backwardSlice.end()); 150 151 // Compute and insert the forwardSlice starting from currentOp. 152 forwardSlice.clear(); 153 getForwardSlice(currentOp, &forwardSlice, forwardFilter); 154 slice.insert(forwardSlice.begin(), forwardSlice.end()); 155 ++currentIndex; 156 } 157 return topologicalSort(slice); 158 } 159 160 namespace { 161 /// DFS post-order implementation that maintains a global count to work across 162 /// multiple invocations, to help implement topological sort on multi-root DAGs. 163 /// We traverse all operations but only record the ones that appear in 164 /// `toSort` for the final result. 165 struct DFSState { 166 DFSState(const SetVector<Operation *> &set) 167 : toSort(set), topologicalCounts(), seen() {} 168 const SetVector<Operation *> &toSort; 169 SmallVector<Operation *, 16> topologicalCounts; 170 DenseSet<Operation *> seen; 171 }; 172 } // namespace 173 174 static void DFSPostorder(Operation *current, DFSState *state) { 175 assert(current->getNumResults() <= 1 && "NYI: multi-result"); 176 if (current->getNumResults() > 0) { 177 for (auto &u : current->getResult(0).getUses()) { 178 auto *op = u.getOwner(); 179 DFSPostorder(op, state); 180 } 181 } 182 bool inserted; 183 using IterTy = decltype(state->seen.begin()); 184 IterTy iter; 185 std::tie(iter, inserted) = state->seen.insert(current); 186 if (inserted) { 187 if (state->toSort.count(current) > 0) { 188 state->topologicalCounts.push_back(current); 189 } 190 } 191 } 192 193 SetVector<Operation *> 194 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 195 if (toSort.empty()) { 196 return toSort; 197 } 198 199 // Run from each root with global count and `seen` set. 200 DFSState state(toSort); 201 for (auto *s : toSort) { 202 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 203 DFSPostorder(s, &state); 204 } 205 206 // Reorder and return. 207 SetVector<Operation *> res; 208 for (auto it = state.topologicalCounts.rbegin(), 209 eit = state.topologicalCounts.rend(); 210 it != eit; ++it) { 211 res.insert(*it); 212 } 213 return res; 214 } 215