1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Transforms/RegionUtils.h" 19 #include "mlir/IR/Block.h" 20 #include "mlir/IR/Operation.h" 21 #include "mlir/IR/RegionGraphTraits.h" 22 #include "mlir/IR/Value.h" 23 24 #include "llvm/ADT/DepthFirstIterator.h" 25 #include "llvm/ADT/PostOrderIterator.h" 26 #include "llvm/ADT/SmallSet.h" 27 28 using namespace mlir; 29 30 void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement, 31 Region ®ion) { 32 for (IROperand &use : llvm::make_early_inc_range(orig->getUses())) { 33 if (region.isAncestor(use.getOwner()->getParentRegion())) 34 use.set(replacement); 35 } 36 } 37 38 void mlir::visitUsedValuesDefinedAbove( 39 Region ®ion, Region &limit, 40 llvm::function_ref<void(OpOperand *)> callback) { 41 assert(limit.isAncestor(®ion) && 42 "expected isolation limit to be an ancestor of the given region"); 43 44 // Collect proper ancestors of `limit` upfront to avoid traversing the region 45 // tree for every value. 46 llvm::SmallPtrSet<Region *, 4> properAncestors; 47 for (auto *reg = limit.getParentRegion(); reg != nullptr; 48 reg = reg->getParentRegion()) { 49 properAncestors.insert(reg); 50 } 51 52 region.walk([callback, &properAncestors](Operation *op) { 53 for (OpOperand &operand : op->getOpOperands()) 54 // Callback on values defined in a proper ancestor of region. 55 if (properAncestors.count(operand.get()->getParentRegion())) 56 callback(&operand); 57 }); 58 } 59 60 void mlir::visitUsedValuesDefinedAbove( 61 llvm::MutableArrayRef<Region> regions, 62 llvm::function_ref<void(OpOperand *)> callback) { 63 for (Region ®ion : regions) 64 visitUsedValuesDefinedAbove(region, region, callback); 65 } 66 67 void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, 68 llvm::SetVector<Value *> &values) { 69 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { 70 values.insert(operand->get()); 71 }); 72 } 73 74 void mlir::getUsedValuesDefinedAbove(llvm::MutableArrayRef<Region> regions, 75 llvm::SetVector<Value *> &values) { 76 for (Region ®ion : regions) 77 getUsedValuesDefinedAbove(region, region, values); 78 } 79 80 //===----------------------------------------------------------------------===// 81 // Unreachable Block Elimination 82 //===----------------------------------------------------------------------===// 83 84 /// Erase the unreachable blocks within the provided regions. Returns success 85 /// if any blocks were erased, failure otherwise. 86 // TODO: We could likely merge this with the DCE algorithm below. 87 static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) { 88 // Set of blocks found to be reachable within a given region. 89 llvm::df_iterator_default_set<Block *, 16> reachable; 90 // If any blocks were found to be dead. 91 bool erasedDeadBlocks = false; 92 93 SmallVector<Region *, 1> worklist; 94 worklist.reserve(regions.size()); 95 for (Region ®ion : regions) 96 worklist.push_back(®ion); 97 while (!worklist.empty()) { 98 Region *region = worklist.pop_back_val(); 99 if (region->empty()) 100 continue; 101 102 // If this is a single block region, just collect the nested regions. 103 if (std::next(region->begin()) == region->end()) { 104 for (Operation &op : region->front()) 105 for (Region ®ion : op.getRegions()) 106 worklist.push_back(®ion); 107 continue; 108 } 109 110 // Mark all reachable blocks. 111 reachable.clear(); 112 for (Block *block : depth_first_ext(®ion->front(), reachable)) 113 (void)block /* Mark all reachable blocks */; 114 115 // Collect all of the dead blocks and push the live regions onto the 116 // worklist. 117 for (Block &block : llvm::make_early_inc_range(*region)) { 118 if (!reachable.count(&block)) { 119 block.dropAllDefinedValueUses(); 120 block.erase(); 121 erasedDeadBlocks = true; 122 continue; 123 } 124 125 // Walk any regions within this block. 126 for (Operation &op : block) 127 for (Region ®ion : op.getRegions()) 128 worklist.push_back(®ion); 129 } 130 } 131 132 return success(erasedDeadBlocks); 133 } 134 135 //===----------------------------------------------------------------------===// 136 // Dead Code Elimination 137 //===----------------------------------------------------------------------===// 138 139 namespace { 140 /// Data structure used to track which values have already been proved live. 141 /// 142 /// Because Operation's can have multiple results, this data structure tracks 143 /// liveness for both Value's and Operation's to avoid having to look through 144 /// all Operation results when analyzing a use. 145 /// 146 /// This data structure essentially tracks the dataflow lattice. 147 /// The set of values/ops proved live increases monotonically to a fixed-point. 148 class LiveMap { 149 public: 150 /// Value methods. 151 bool wasProvenLive(Value *value) { return liveValues.count(value); } 152 void setProvedLive(Value *value) { 153 changed |= liveValues.insert(value).second; 154 } 155 156 /// Operation methods. 157 bool wasProvenLive(Operation *op) { return liveOps.count(op); } 158 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; } 159 160 /// Methods for tracking if we have reached a fixed-point. 161 void resetChanged() { changed = false; } 162 bool hasChanged() { return changed; } 163 164 private: 165 bool changed = false; 166 DenseSet<Value *> liveValues; 167 DenseSet<Operation *> liveOps; 168 }; 169 } // namespace 170 171 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { 172 Operation *owner = use.getOwner(); 173 unsigned operandIndex = use.getOperandNumber(); 174 // This pass generally treats all uses of an op as live if the op itself is 175 // considered live. However, for successor operands to terminators we need a 176 // finer-grained notion where we deduce liveness for operands individually. 177 // The reason for this is easiest to think about in terms of a classical phi 178 // node based SSA IR, where each successor operand is really an operand to a 179 // *separate* phi node, rather than all operands to the branch itself as with 180 // the block argument representation that MLIR uses. 181 // 182 // And similarly, because each successor operand is really an operand to a phi 183 // node, rather than to the terminator op itself, a terminator op can't e.g. 184 // "print" the value of a successor operand. 185 if (owner->isKnownTerminator()) { 186 if (auto arg = owner->getSuccessorBlockArgument(operandIndex)) 187 return !liveMap.wasProvenLive(*arg); 188 return false; 189 } 190 return false; 191 } 192 193 static void processValue(Value *value, LiveMap &liveMap) { 194 bool provedLive = llvm::any_of(value->getUses(), [&](OpOperand &use) { 195 if (isUseSpeciallyKnownDead(use, liveMap)) 196 return false; 197 return liveMap.wasProvenLive(use.getOwner()); 198 }); 199 if (provedLive) 200 liveMap.setProvedLive(value); 201 } 202 203 static bool isOpIntrinsicallyLive(Operation *op) { 204 // This pass doesn't modify the CFG, so terminators are never deleted. 205 if (!op->isKnownNonTerminator()) 206 return true; 207 // If the op has a side effect, we treat it as live. 208 if (!op->hasNoSideEffect()) 209 return true; 210 return false; 211 } 212 213 static void propagateLiveness(Region ®ion, LiveMap &liveMap); 214 static void propagateLiveness(Operation *op, LiveMap &liveMap) { 215 // All Value's are either a block argument or an op result. 216 // We call processValue on those cases. 217 218 // Recurse on any regions the op has. 219 for (Region ®ion : op->getRegions()) 220 propagateLiveness(region, liveMap); 221 222 // Process the op itself. 223 if (isOpIntrinsicallyLive(op)) { 224 liveMap.setProvedLive(op); 225 return; 226 } 227 for (Value *value : op->getResults()) 228 processValue(value, liveMap); 229 bool provedLive = llvm::any_of(op->getResults(), [&](Value *value) { 230 return liveMap.wasProvenLive(value); 231 }); 232 if (provedLive) 233 liveMap.setProvedLive(op); 234 } 235 236 static void propagateLiveness(Region ®ion, LiveMap &liveMap) { 237 if (region.empty()) 238 return; 239 240 for (Block *block : llvm::post_order(®ion.front())) { 241 // We process block arguments after the ops in the block, to promote 242 // faster convergence to a fixed point (we try to visit uses before defs). 243 for (Operation &op : llvm::reverse(block->getOperations())) 244 propagateLiveness(&op, liveMap); 245 for (Value *value : block->getArguments()) 246 processValue(value, liveMap); 247 } 248 } 249 250 static void eraseTerminatorSuccessorOperands(Operation *terminator, 251 LiveMap &liveMap) { 252 for (unsigned succI = 0, succE = terminator->getNumSuccessors(); 253 succI < succE; succI++) { 254 // Iterating successors in reverse is not strictly needed, since we 255 // aren't erasing any successors. But it is slightly more efficient 256 // since it will promote later operands of the terminator being erased 257 // first, reducing the quadratic-ness. 258 unsigned succ = succE - succI - 1; 259 for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ); 260 argI < argE; argI++) { 261 // Iterating args in reverse is needed for correctness, to avoid 262 // shifting later args when earlier args are erased. 263 unsigned arg = argE - argI - 1; 264 Value *value = terminator->getSuccessor(succ)->getArgument(arg); 265 if (!liveMap.wasProvenLive(value)) { 266 terminator->eraseSuccessorOperand(succ, arg); 267 } 268 } 269 } 270 } 271 272 static LogicalResult deleteDeadness(MutableArrayRef<Region> regions, 273 LiveMap &liveMap) { 274 bool erasedAnything = false; 275 for (Region ®ion : regions) { 276 if (region.empty()) 277 continue; 278 279 // We do the deletion in an order that deletes all uses before deleting 280 // defs. 281 // MLIR's SSA structural invariants guarantee that except for block 282 // arguments, the use-def graph is acyclic, so this is possible with a 283 // single walk of ops and then a final pass to clean up block arguments. 284 // 285 // To do this, we visit ops in an order that visits domtree children 286 // before domtree parents. A CFG post-order (with reverse iteration with a 287 // block) satisfies that without needing an explicit domtree calculation. 288 for (Block *block : llvm::post_order(®ion.front())) { 289 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); 290 for (Operation &childOp : 291 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { 292 erasedAnything |= 293 succeeded(deleteDeadness(childOp.getRegions(), liveMap)); 294 if (!liveMap.wasProvenLive(&childOp)) { 295 erasedAnything = true; 296 childOp.erase(); 297 } 298 } 299 } 300 // Delete block arguments. 301 // The entry block has an unknown contract with their enclosing block, so 302 // skip it. 303 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) { 304 // Iterate in reverse to avoid shifting later arguments when deleting 305 // earlier arguments. 306 for (unsigned i = 0, e = block.getNumArguments(); i < e; i++) 307 if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) { 308 block.eraseArgument(e - i - 1, /*updatePredTerms=*/false); 309 erasedAnything = true; 310 } 311 } 312 } 313 return success(erasedAnything); 314 } 315 316 // This function performs a simple dead code elimination algorithm over the 317 // given regions. 318 // 319 // The overall goal is to prove that Values are dead, which allows deleting ops 320 // and block arguments. 321 // 322 // This uses an optimistic algorithm that assumes everything is dead until 323 // proved otherwise, allowing it to delete recursively dead cycles. 324 // 325 // This is a simple fixed-point dataflow analysis algorithm on a lattice 326 // {Dead,Alive}. Because liveness flows backward, we generally try to 327 // iterate everything backward to speed up convergence to the fixed-point. This 328 // allows for being able to delete recursively dead cycles of the use-def graph, 329 // including block arguments. 330 // 331 // This function returns success if any operations or arguments were deleted, 332 // failure otherwise. 333 static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) { 334 assert(regions.size() == 1); 335 336 LiveMap liveMap; 337 do { 338 liveMap.resetChanged(); 339 340 for (Region ®ion : regions) 341 propagateLiveness(region, liveMap); 342 } while (liveMap.hasChanged()); 343 344 return deleteDeadness(regions, liveMap); 345 } 346 347 //===----------------------------------------------------------------------===// 348 // Region Simplification 349 //===----------------------------------------------------------------------===// 350 351 /// Run a set of structural simplifications over the given regions. This 352 /// includes transformations like unreachable block elimination, dead argument 353 /// elimination, as well as some other DCE. This function returns success if any 354 /// of the regions were simplified, failure otherwise. 355 LogicalResult mlir::simplifyRegions(llvm::MutableArrayRef<Region> regions) { 356 LogicalResult eliminatedBlocks = eraseUnreachableBlocks(regions); 357 LogicalResult eliminatedOpsOrArgs = runRegionDCE(regions); 358 return success(succeeded(eliminatedBlocks) || succeeded(eliminatedOpsOrArgs)); 359 } 360