1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/RegionUtils.h" 10 #include "mlir/IR/Block.h" 11 #include "mlir/IR/Operation.h" 12 #include "mlir/IR/RegionGraphTraits.h" 13 #include "mlir/IR/Value.h" 14 15 #include "llvm/ADT/DepthFirstIterator.h" 16 #include "llvm/ADT/PostOrderIterator.h" 17 #include "llvm/ADT/SmallSet.h" 18 19 using namespace mlir; 20 21 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, 22 Region ®ion) { 23 for (auto &use : llvm::make_early_inc_range(orig.getUses())) { 24 if (region.isAncestor(use.getOwner()->getParentRegion())) 25 use.set(replacement); 26 } 27 } 28 29 void mlir::visitUsedValuesDefinedAbove( 30 Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { 31 assert(limit.isAncestor(®ion) && 32 "expected isolation limit to be an ancestor of the given region"); 33 34 // Collect proper ancestors of `limit` upfront to avoid traversing the region 35 // tree for every value. 36 SmallPtrSet<Region *, 4> properAncestors; 37 for (auto *reg = limit.getParentRegion(); reg != nullptr; 38 reg = reg->getParentRegion()) { 39 properAncestors.insert(reg); 40 } 41 42 region.walk([callback, &properAncestors](Operation *op) { 43 for (OpOperand &operand : op->getOpOperands()) 44 // Callback on values defined in a proper ancestor of region. 45 if (properAncestors.count(operand.get().getParentRegion())) 46 callback(&operand); 47 }); 48 } 49 50 void mlir::visitUsedValuesDefinedAbove( 51 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { 52 for (Region ®ion : regions) 53 visitUsedValuesDefinedAbove(region, region, callback); 54 } 55 56 void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, 57 llvm::SetVector<Value> &values) { 58 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { 59 values.insert(operand->get()); 60 }); 61 } 62 63 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, 64 llvm::SetVector<Value> &values) { 65 for (Region ®ion : regions) 66 getUsedValuesDefinedAbove(region, region, values); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // Unreachable Block Elimination 71 //===----------------------------------------------------------------------===// 72 73 /// Erase the unreachable blocks within the provided regions. Returns success 74 /// if any blocks were erased, failure otherwise. 75 // TODO: We could likely merge this with the DCE algorithm below. 76 static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) { 77 // Set of blocks found to be reachable within a given region. 78 llvm::df_iterator_default_set<Block *, 16> reachable; 79 // If any blocks were found to be dead. 80 bool erasedDeadBlocks = false; 81 82 SmallVector<Region *, 1> worklist; 83 worklist.reserve(regions.size()); 84 for (Region ®ion : regions) 85 worklist.push_back(®ion); 86 while (!worklist.empty()) { 87 Region *region = worklist.pop_back_val(); 88 if (region->empty()) 89 continue; 90 91 // If this is a single block region, just collect the nested regions. 92 if (std::next(region->begin()) == region->end()) { 93 for (Operation &op : region->front()) 94 for (Region ®ion : op.getRegions()) 95 worklist.push_back(®ion); 96 continue; 97 } 98 99 // Mark all reachable blocks. 100 reachable.clear(); 101 for (Block *block : depth_first_ext(®ion->front(), reachable)) 102 (void)block /* Mark all reachable blocks */; 103 104 // Collect all of the dead blocks and push the live regions onto the 105 // worklist. 106 for (Block &block : llvm::make_early_inc_range(*region)) { 107 if (!reachable.count(&block)) { 108 block.dropAllDefinedValueUses(); 109 block.erase(); 110 erasedDeadBlocks = true; 111 continue; 112 } 113 114 // Walk any regions within this block. 115 for (Operation &op : block) 116 for (Region ®ion : op.getRegions()) 117 worklist.push_back(®ion); 118 } 119 } 120 121 return success(erasedDeadBlocks); 122 } 123 124 //===----------------------------------------------------------------------===// 125 // Dead Code Elimination 126 //===----------------------------------------------------------------------===// 127 128 namespace { 129 /// Data structure used to track which values have already been proved live. 130 /// 131 /// Because Operation's can have multiple results, this data structure tracks 132 /// liveness for both Value's and Operation's to avoid having to look through 133 /// all Operation results when analyzing a use. 134 /// 135 /// This data structure essentially tracks the dataflow lattice. 136 /// The set of values/ops proved live increases monotonically to a fixed-point. 137 class LiveMap { 138 public: 139 /// Value methods. 140 bool wasProvenLive(Value value) { return liveValues.count(value); } 141 void setProvedLive(Value value) { 142 changed |= liveValues.insert(value).second; 143 } 144 145 /// Operation methods. 146 bool wasProvenLive(Operation *op) { return liveOps.count(op); } 147 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; } 148 149 /// Methods for tracking if we have reached a fixed-point. 150 void resetChanged() { changed = false; } 151 bool hasChanged() { return changed; } 152 153 private: 154 bool changed = false; 155 DenseSet<Value> liveValues; 156 DenseSet<Operation *> liveOps; 157 }; 158 } // namespace 159 160 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { 161 Operation *owner = use.getOwner(); 162 unsigned operandIndex = use.getOperandNumber(); 163 // This pass generally treats all uses of an op as live if the op itself is 164 // considered live. However, for successor operands to terminators we need a 165 // finer-grained notion where we deduce liveness for operands individually. 166 // The reason for this is easiest to think about in terms of a classical phi 167 // node based SSA IR, where each successor operand is really an operand to a 168 // *separate* phi node, rather than all operands to the branch itself as with 169 // the block argument representation that MLIR uses. 170 // 171 // And similarly, because each successor operand is really an operand to a phi 172 // node, rather than to the terminator op itself, a terminator op can't e.g. 173 // "print" the value of a successor operand. 174 if (owner->isKnownTerminator()) { 175 if (auto arg = owner->getSuccessorBlockArgument(operandIndex)) 176 return !liveMap.wasProvenLive(*arg); 177 return false; 178 } 179 return false; 180 } 181 182 static void processValue(Value value, LiveMap &liveMap) { 183 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) { 184 if (isUseSpeciallyKnownDead(use, liveMap)) 185 return false; 186 return liveMap.wasProvenLive(use.getOwner()); 187 }); 188 if (provedLive) 189 liveMap.setProvedLive(value); 190 } 191 192 static bool isOpIntrinsicallyLive(Operation *op) { 193 // This pass doesn't modify the CFG, so terminators are never deleted. 194 if (!op->isKnownNonTerminator()) 195 return true; 196 // If the op has a side effect, we treat it as live. 197 if (!op->hasNoSideEffect()) 198 return true; 199 return false; 200 } 201 202 static void propagateLiveness(Region ®ion, LiveMap &liveMap); 203 static void propagateLiveness(Operation *op, LiveMap &liveMap) { 204 // All Value's are either a block argument or an op result. 205 // We call processValue on those cases. 206 207 // Recurse on any regions the op has. 208 for (Region ®ion : op->getRegions()) 209 propagateLiveness(region, liveMap); 210 211 // Process the op itself. 212 if (isOpIntrinsicallyLive(op)) { 213 liveMap.setProvedLive(op); 214 return; 215 } 216 for (Value value : op->getResults()) 217 processValue(value, liveMap); 218 bool provedLive = llvm::any_of(op->getResults(), [&](Value value) { 219 return liveMap.wasProvenLive(value); 220 }); 221 if (provedLive) 222 liveMap.setProvedLive(op); 223 } 224 225 static void propagateLiveness(Region ®ion, LiveMap &liveMap) { 226 if (region.empty()) 227 return; 228 229 for (Block *block : llvm::post_order(®ion.front())) { 230 // We process block arguments after the ops in the block, to promote 231 // faster convergence to a fixed point (we try to visit uses before defs). 232 for (Operation &op : llvm::reverse(block->getOperations())) 233 propagateLiveness(&op, liveMap); 234 for (Value value : block->getArguments()) 235 processValue(value, liveMap); 236 } 237 } 238 239 static void eraseTerminatorSuccessorOperands(Operation *terminator, 240 LiveMap &liveMap) { 241 for (unsigned succI = 0, succE = terminator->getNumSuccessors(); 242 succI < succE; succI++) { 243 // Iterating successors in reverse is not strictly needed, since we 244 // aren't erasing any successors. But it is slightly more efficient 245 // since it will promote later operands of the terminator being erased 246 // first, reducing the quadratic-ness. 247 unsigned succ = succE - succI - 1; 248 for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ); 249 argI < argE; argI++) { 250 // Iterating args in reverse is needed for correctness, to avoid 251 // shifting later args when earlier args are erased. 252 unsigned arg = argE - argI - 1; 253 Value value = terminator->getSuccessor(succ)->getArgument(arg); 254 if (!liveMap.wasProvenLive(value)) { 255 terminator->eraseSuccessorOperand(succ, arg); 256 } 257 } 258 } 259 } 260 261 static LogicalResult deleteDeadness(MutableArrayRef<Region> regions, 262 LiveMap &liveMap) { 263 bool erasedAnything = false; 264 for (Region ®ion : regions) { 265 if (region.empty()) 266 continue; 267 268 // We do the deletion in an order that deletes all uses before deleting 269 // defs. 270 // MLIR's SSA structural invariants guarantee that except for block 271 // arguments, the use-def graph is acyclic, so this is possible with a 272 // single walk of ops and then a final pass to clean up block arguments. 273 // 274 // To do this, we visit ops in an order that visits domtree children 275 // before domtree parents. A CFG post-order (with reverse iteration with a 276 // block) satisfies that without needing an explicit domtree calculation. 277 for (Block *block : llvm::post_order(®ion.front())) { 278 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); 279 for (Operation &childOp : 280 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { 281 erasedAnything |= 282 succeeded(deleteDeadness(childOp.getRegions(), liveMap)); 283 if (!liveMap.wasProvenLive(&childOp)) { 284 erasedAnything = true; 285 childOp.erase(); 286 } 287 } 288 } 289 // Delete block arguments. 290 // The entry block has an unknown contract with their enclosing block, so 291 // skip it. 292 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) { 293 // Iterate in reverse to avoid shifting later arguments when deleting 294 // earlier arguments. 295 for (unsigned i = 0, e = block.getNumArguments(); i < e; i++) 296 if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) { 297 block.eraseArgument(e - i - 1, /*updatePredTerms=*/false); 298 erasedAnything = true; 299 } 300 } 301 } 302 return success(erasedAnything); 303 } 304 305 // This function performs a simple dead code elimination algorithm over the 306 // given regions. 307 // 308 // The overall goal is to prove that Values are dead, which allows deleting ops 309 // and block arguments. 310 // 311 // This uses an optimistic algorithm that assumes everything is dead until 312 // proved otherwise, allowing it to delete recursively dead cycles. 313 // 314 // This is a simple fixed-point dataflow analysis algorithm on a lattice 315 // {Dead,Alive}. Because liveness flows backward, we generally try to 316 // iterate everything backward to speed up convergence to the fixed-point. This 317 // allows for being able to delete recursively dead cycles of the use-def graph, 318 // including block arguments. 319 // 320 // This function returns success if any operations or arguments were deleted, 321 // failure otherwise. 322 static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) { 323 LiveMap liveMap; 324 do { 325 liveMap.resetChanged(); 326 327 for (Region ®ion : regions) 328 propagateLiveness(region, liveMap); 329 } while (liveMap.hasChanged()); 330 331 return deleteDeadness(regions, liveMap); 332 } 333 334 //===----------------------------------------------------------------------===// 335 // Region Simplification 336 //===----------------------------------------------------------------------===// 337 338 /// Run a set of structural simplifications over the given regions. This 339 /// includes transformations like unreachable block elimination, dead argument 340 /// elimination, as well as some other DCE. This function returns success if any 341 /// of the regions were simplified, failure otherwise. 342 LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) { 343 LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions); 344 LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions); 345 return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs)); 346 } 347