1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements Analysis functions specific to slicing in Function. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/Support/LLVM.h" 17 #include "llvm/ADT/SetVector.h" 18 #include "llvm/ADT/SmallPtrSet.h" 19 20 /// 21 /// Implements Analysis functions specific to slicing in Function. 22 /// 23 24 using namespace mlir; 25 26 static void getForwardSliceImpl(Operation *op, 27 SetVector<Operation *> *forwardSlice, 28 TransitiveFilter filter) { 29 if (!op) 30 return; 31 32 // Evaluate whether we should keep this use. 33 // This is useful in particular to implement scoping; i.e. return the 34 // transitive forwardSlice in the current scope. 35 if (filter && !filter(op)) 36 return; 37 38 for (Region ®ion : op->getRegions()) 39 for (Block &block : region) 40 for (Operation &blockOp : block) 41 if (forwardSlice->count(&blockOp) == 0) 42 getForwardSliceImpl(&blockOp, forwardSlice, filter); 43 for (Value result : op->getResults()) { 44 for (Operation *userOp : result.getUsers()) 45 if (forwardSlice->count(userOp) == 0) 46 getForwardSliceImpl(userOp, forwardSlice, filter); 47 } 48 49 forwardSlice->insert(op); 50 } 51 52 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, 53 TransitiveFilter filter) { 54 getForwardSliceImpl(op, forwardSlice, filter); 55 // Don't insert the top level operation, we just queried on it and don't 56 // want it in the results. 57 forwardSlice->remove(op); 58 59 // Reverse to get back the actual topological order. 60 // std::reverse does not work out of the box on SetVector and I want an 61 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 62 std::vector<Operation *> v(forwardSlice->takeVector()); 63 forwardSlice->insert(v.rbegin(), v.rend()); 64 } 65 66 void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice, 67 TransitiveFilter filter) { 68 for (Operation *user : root.getUsers()) 69 getForwardSliceImpl(user, forwardSlice, filter); 70 71 // Reverse to get back the actual topological order. 72 // std::reverse does not work out of the box on SetVector and I want an 73 // in-place swap based thing (the real std::reverse, not the LLVM adapter). 74 std::vector<Operation *> v(forwardSlice->takeVector()); 75 forwardSlice->insert(v.rbegin(), v.rend()); 76 } 77 78 static void getBackwardSliceImpl(Operation *op, 79 SetVector<Operation *> *backwardSlice, 80 TransitiveFilter filter) { 81 if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 82 return; 83 84 // Evaluate whether we should keep this def. 85 // This is useful in particular to implement scoping; i.e. return the 86 // transitive backwardSlice in the current scope. 87 if (filter && !filter(op)) 88 return; 89 90 for (const auto &en : llvm::enumerate(op->getOperands())) { 91 auto operand = en.value(); 92 if (auto *definingOp = operand.getDefiningOp()) { 93 if (backwardSlice->count(definingOp) == 0) 94 getBackwardSliceImpl(definingOp, backwardSlice, filter); 95 } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) { 96 Block *block = blockArg.getOwner(); 97 Operation *parentOp = block->getParentOp(); 98 // TODO: determine whether we want to recurse backward into the other 99 // blocks of parentOp, which are not technically backward unless they flow 100 // into us. For now, just bail. 101 assert(parentOp->getNumRegions() == 1 && 102 parentOp->getRegion(0).getBlocks().size() == 1); 103 if (backwardSlice->count(parentOp) == 0) 104 getBackwardSliceImpl(parentOp, backwardSlice, filter); 105 } else { 106 llvm_unreachable("No definingOp and not a block argument."); 107 } 108 } 109 110 backwardSlice->insert(op); 111 } 112 113 void mlir::getBackwardSlice(Operation *op, 114 SetVector<Operation *> *backwardSlice, 115 TransitiveFilter filter) { 116 getBackwardSliceImpl(op, backwardSlice, filter); 117 118 // Don't insert the top level operation, we just queried on it and don't 119 // want it in the results. 120 backwardSlice->remove(op); 121 } 122 123 void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice, 124 TransitiveFilter filter) { 125 if (Operation *definingOp = root.getDefiningOp()) { 126 getBackwardSlice(definingOp, backwardSlice, filter); 127 return; 128 } 129 Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp(); 130 getBackwardSlice(bbAargOwner, backwardSlice, filter); 131 } 132 133 SetVector<Operation *> mlir::getSlice(Operation *op, 134 TransitiveFilter backwardFilter, 135 TransitiveFilter forwardFilter) { 136 SetVector<Operation *> slice; 137 slice.insert(op); 138 139 unsigned currentIndex = 0; 140 SetVector<Operation *> backwardSlice; 141 SetVector<Operation *> forwardSlice; 142 while (currentIndex != slice.size()) { 143 auto *currentOp = (slice)[currentIndex]; 144 // Compute and insert the backwardSlice starting from currentOp. 145 backwardSlice.clear(); 146 getBackwardSlice(currentOp, &backwardSlice, backwardFilter); 147 slice.insert(backwardSlice.begin(), backwardSlice.end()); 148 149 // Compute and insert the forwardSlice starting from currentOp. 150 forwardSlice.clear(); 151 getForwardSlice(currentOp, &forwardSlice, forwardFilter); 152 slice.insert(forwardSlice.begin(), forwardSlice.end()); 153 ++currentIndex; 154 } 155 return topologicalSort(slice); 156 } 157 158 namespace { 159 /// DFS post-order implementation that maintains a global count to work across 160 /// multiple invocations, to help implement topological sort on multi-root DAGs. 161 /// We traverse all operations but only record the ones that appear in 162 /// `toSort` for the final result. 163 struct DFSState { 164 DFSState(const SetVector<Operation *> &set) 165 : toSort(set), topologicalCounts(), seen() {} 166 const SetVector<Operation *> &toSort; 167 SmallVector<Operation *, 16> topologicalCounts; 168 DenseSet<Operation *> seen; 169 }; 170 } // namespace 171 172 static void dfsPostorder(Operation *root, DFSState *state) { 173 SmallVector<Operation *> queue(1, root); 174 std::vector<Operation *> ops; 175 while (!queue.empty()) { 176 Operation *current = queue.pop_back_val(); 177 ops.push_back(current); 178 for (Value result : current->getResults()) { 179 for (Operation *op : result.getUsers()) 180 queue.push_back(op); 181 } 182 for (Region ®ion : current->getRegions()) { 183 for (Operation &op : region.getOps()) 184 queue.push_back(&op); 185 } 186 } 187 188 for (Operation *op : llvm::reverse(ops)) { 189 if (state->seen.insert(op).second && state->toSort.count(op) > 0) 190 state->topologicalCounts.push_back(op); 191 } 192 } 193 194 SetVector<Operation *> 195 mlir::topologicalSort(const SetVector<Operation *> &toSort) { 196 if (toSort.empty()) { 197 return toSort; 198 } 199 200 // Run from each root with global count and `seen` set. 201 DFSState state(toSort); 202 for (auto *s : toSort) { 203 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); 204 dfsPostorder(s, &state); 205 } 206 207 // Reorder and return. 208 SetVector<Operation *> res; 209 for (auto it = state.topologicalCounts.rbegin(), 210 eit = state.topologicalCounts.rend(); 211 it != eit; ++it) { 212 res.insert(*it); 213 } 214 return res; 215 } 216 217 /// Returns true if `value` (transitively) depends on iteration-carried values 218 /// of the given `ancestorOp`. 219 static bool dependsOnCarriedVals(Value value, 220 ArrayRef<BlockArgument> iterCarriedArgs, 221 Operation *ancestorOp) { 222 // Compute the backward slice of the value. 223 SetVector<Operation *> slice; 224 getBackwardSlice(value, &slice, 225 [&](Operation *op) { return !ancestorOp->isAncestor(op); }); 226 227 // Check that none of the operands of the operations in the backward slice are 228 // loop iteration arguments, and neither is the value itself. 229 SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(), 230 iterCarriedArgs.end()); 231 if (iterCarriedValSet.contains(value)) 232 return true; 233 234 for (Operation *op : slice) 235 for (Value operand : op->getOperands()) 236 if (iterCarriedValSet.contains(operand)) 237 return true; 238 239 return false; 240 } 241 242 /// Utility to match a generic reduction given a list of iteration-carried 243 /// arguments, `iterCarriedArgs` and the position of the potential reduction 244 /// argument within the list, `redPos`. If a reduction is matched, returns the 245 /// reduced value and the topologically-sorted list of combiner operations 246 /// involved in the reduction. Otherwise, returns a null value. 247 /// 248 /// The matching algorithm relies on the following invariants, which are subject 249 /// to change: 250 /// 1. The first combiner operation must be a binary operation with the 251 /// iteration-carried value and the reduced value as operands. 252 /// 2. The iteration-carried value and combiner operations must be side 253 /// effect-free, have single result and a single use. 254 /// 3. Combiner operations must be immediately nested in the region op 255 /// performing the reduction. 256 /// 4. Reduction def-use chain must end in a terminator op that yields the 257 /// next iteration/output values in the same order as the iteration-carried 258 /// values in `iterCarriedArgs`. 259 /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values 260 /// of the region op performing the reduction. 261 /// 262 /// This utility is generic enough to detect reductions involving multiple 263 /// combiner operations (disabled for now) across multiple dialects, including 264 /// Linalg, Affine and SCF. For the sake of genericity, it does not return 265 /// specific enum values for the combiner operations since its goal is also 266 /// matching reductions without pre-defined semantics in core MLIR. It's up to 267 /// each client to make sense out of the list of combiner operations. It's also 268 /// up to each client to check for additional invariants on the expected 269 /// reductions not covered by this generic matching. 270 Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs, 271 unsigned redPos, 272 SmallVectorImpl<Operation *> &combinerOps) { 273 assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds"); 274 275 BlockArgument redCarriedVal = iterCarriedArgs[redPos]; 276 if (!redCarriedVal.hasOneUse()) 277 return nullptr; 278 279 // For now, the first combiner op must be a binary op. 280 Operation *combinerOp = *redCarriedVal.getUsers().begin(); 281 if (combinerOp->getNumOperands() != 2) 282 return nullptr; 283 Value reducedVal = combinerOp->getOperand(0) == redCarriedVal 284 ? combinerOp->getOperand(1) 285 : combinerOp->getOperand(0); 286 287 Operation *redRegionOp = 288 iterCarriedArgs.front().getOwner()->getParent()->getParentOp(); 289 if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp)) 290 return nullptr; 291 292 // Traverse the def-use chain starting from the first combiner op until a 293 // terminator is found. Gather all the combiner ops along the way in 294 // topological order. 295 while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) { 296 if (!MemoryEffectOpInterface::hasNoEffect(combinerOp) || 297 combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() || 298 combinerOp->getParentOp() != redRegionOp) 299 return nullptr; 300 301 combinerOps.push_back(combinerOp); 302 combinerOp = *combinerOp->getUsers().begin(); 303 } 304 305 // Limit matching to single combiner op until we can properly test reductions 306 // involving multiple combiners. 307 if (combinerOps.size() != 1) 308 return nullptr; 309 310 // Check that the yielded value is in the same position as in 311 // `iterCarriedArgs`. 312 Operation *terminatorOp = combinerOp; 313 if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0]) 314 return nullptr; 315 316 return reducedVal; 317 } 318