1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/IR/Function.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/Support/LLVM.h" 20 #include "llvm/ADT/SetVector.h" 21 22 /// 23 /// Implements Analysis functions specific to slicing in Function. 24 /// 25 26 using namespace mlir; 27 28 using llvm::SetVector; 29 30 static void getForwardSliceImpl(Operation *op, 31 SetVector<Operation *> *forwardSlice, 32 TransitiveFilter filter) { 33 if (!op) { 34 return; 35 } 36 37 // Evaluate whether we should keep this use. 38 // This is useful in particular to implement scoping; i.e. return the 39 // transitive forwardSlice in the current scope. 40 if (!filter(op)) { 41 return; 42 } 43 44 if (auto forOp = dyn_cast<AffineForOp>(op)) { 45 for (Operation *userOp : forOp.getInductionVar().getUsers()) 46 if (forwardSlice->count(userOp) == 0) 47 getForwardSliceImpl(userOp, forwardSlice, filter); 48 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) { 49 for (Operation *userOp : forOp.getInductionVar().getUsers()) 50 if (forwardSlice->count(userOp) == 0) 51 getForwardSliceImpl(userOp, forwardSlice, filter); 52 for (Value result : forOp.getResults()) 53 for (Operation *userOp : result.getUsers()) 54 if (forwardSlice->count(userOp) == 0) 55 getForwardSliceImpl(userOp, forwardSlice, filter); 56 } else { 57 assert(op->getNumRegions() == 0 && "unexpected generic op with regions"); 58 for (Value result : op->getResults()) { 59 for (Operation *userOp : result.getUsers()) 60 if (forwardSlice->count(userOp) == 0) 61 getForwardSliceImpl(userOp, forwardSlice, filter); 62 } 63 } 64 65 forwardSlice->insert(op); 66 } 67 68 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 69 TransitiveFilter filter) { 70 getForwardSliceImpl(op, forwardSlice, filter); 71 // Don't insert the top level operation, we just queried on it and don't 72 // want it in the results. 73 forwardSlice->remove(op); 74 75 // Reverse to get back the actual topological order. 76 // std::reverse does not work out of the box on SetVector and I want an 77 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 78 std::vector<Operation *> v(forwardSlice->takeVector()); 79 forwardSlice->insert(v.rbegin(), v.rend()); 80 } 81 82 static void getBackwardSliceImpl(Operation *op, 83 SetVector<Operation *> *backwardSlice, 84 TransitiveFilter filter) { 85 if (!op) 86 return; 87 88 assert((op->getNumRegions() == 0 || 89 isa<AffineForOp, scf::ForOp, linalg::LinalgOp>(op)) && 90 "unexpected generic op with regions"); 91 92 // Evaluate whether we should keep this def. 93 // This is useful in particular to implement scoping; i.e. return the 94 // transitive forwardSlice in the current scope. 95 if (!filter(op)) { 96 return; 97 } 98 99 for (auto en : llvm::enumerate(op->getOperands())) { 100 auto operand = en.value(); 101 if (auto blockArg = operand.dyn_cast<BlockArgument>()) { 102 if (auto affIv = getForInductionVarOwner(operand)) { 103 auto *affOp = affIv.getOperation(); 104 if (backwardSlice->count(affOp) == 0) 105 getBackwardSliceImpl(affOp, backwardSlice, filter); 106 } else if (auto loopIv = scf::getForInductionVarOwner(operand)) { 107 auto *loopOp = loopIv.getOperation(); 108 if (backwardSlice->count(loopOp) == 0) 109 getBackwardSliceImpl(loopOp, backwardSlice, filter); 110 } else if (blockArg.getOwner() != 111 &op->getParentOfType<FuncOp>().getBody().front()) { 112 op->emitError("unsupported CF for operand ") << en.index(); 113 llvm_unreachable("Unsupported control flow"); 114 } 115 continue; 116 } 117 auto *op = operand.getDefiningOp(); 118 if (backwardSlice->count(op) == 0) { 119 getBackwardSliceImpl(op, backwardSlice, filter); 120 } 121 } 122 123 backwardSlice->insert(op); 124 } 125 126 void mlir::getBackwardSlice(Operation *op, 127 SetVector<Operation *> *backwardSlice, 128 TransitiveFilter filter) { 129 getBackwardSliceImpl(op, backwardSlice, filter); 130 131 // Don't insert the top level operation, we just queried on it and don't 132 // want it in the results. 133 backwardSlice->remove(op); 134 } 135 136 SetVector<Operation *> mlir::getSlice(Operation *op, 137 TransitiveFilter backwardFilter, 138 TransitiveFilter forwardFilter) { 139 SetVector<Operation *> slice; 140 slice.insert(op); 141 142 unsigned currentIndex = 0; 143 SetVector<Operation *> backwardSlice; 144 SetVector<Operation *> forwardSlice; 145 while (currentIndex != slice.size()) { 146 auto *currentOp = (slice)[currentIndex]; 147 // Compute and insert the backwardSlice starting from currentOp. 148 backwardSlice.clear(); 149 getBackwardSlice(currentOp, &backwardSlice, backwardFilter); 150 slice.insert(backwardSlice.begin(), backwardSlice.end()); 151 152 // Compute and insert the forwardSlice starting from currentOp. 153 forwardSlice.clear(); 154 getForwardSlice(currentOp, &forwardSlice, forwardFilter); 155 slice.insert(forwardSlice.begin(), forwardSlice.end()); 156 ++currentIndex; 157 } 158 return topologicalSort(slice); 159 } 160 161 namespace { 162 /// DFS post-order implementation that maintains a global count to work across 163 /// multiple invocations, to help implement topological sort on multi-root DAGs. 164 /// We traverse all operations but only record the ones that appear in 165 /// `toSort` for the final result. 166 struct DFSState { 167 DFSState(const SetVector<Operation *> &set) 168 : toSort(set), topologicalCounts(), seen() {} 169 const SetVector<Operation *> &toSort; 170 SmallVector<Operation *, 16> topologicalCounts; 171 DenseSet<Operation *> seen; 172 }; 173 } // namespace 174 175 static void DFSPostorder(Operation *current, DFSState *state) { 176 for (Value result : current->getResults()) { 177 for (Operation *op : result.getUsers()) 178 DFSPostorder(op, state); 179 } 180 bool inserted; 181 using IterTy = decltype(state->seen.begin()); 182 IterTy iter; 183 std::tie(iter, inserted) = state->seen.insert(current); 184 if (inserted) { 185 if (state->toSort.count(current) > 0) { 186 state->topologicalCounts.push_back(current); 187 } 188 } 189 } 190 191 SetVector<Operation *> 192 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 193 if (toSort.empty()) { 194 return toSort; 195 } 196 197 // Run from each root with global count and `seen` set. 198 DFSState state(toSort); 199 for (auto *s : toSort) { 200 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 201 DFSPostorder(s, &state); 202 } 203 204 // Reorder and return. 205 SetVector<Operation *> res; 206 for (auto it = state.topologicalCounts.rbegin(), 207 eit = state.topologicalCounts.rend(); 208 it != eit; ++it) { 209 res.insert(*it); 210 } 211 return res; 212 } 213