1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/RegionUtils.h" 10 #include "mlir/IR/Block.h" 11 #include "mlir/IR/Operation.h" 12 #include "mlir/IR/RegionGraphTraits.h" 13 #include "mlir/IR/Value.h" 14 #include "mlir/Interfaces/ControlFlowInterfaces.h" 15 #include "mlir/Interfaces/SideEffectInterfaces.h" 16 17 #include "llvm/ADT/DepthFirstIterator.h" 18 #include "llvm/ADT/PostOrderIterator.h" 19 #include "llvm/ADT/SmallSet.h" 20 21 using namespace mlir; 22 23 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, 24 Region ®ion) { 25 for (auto &use : llvm::make_early_inc_range(orig.getUses())) { 26 if (region.isAncestor(use.getOwner()->getParentRegion())) 27 use.set(replacement); 28 } 29 } 30 31 void mlir::visitUsedValuesDefinedAbove( 32 Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { 33 assert(limit.isAncestor(®ion) && 34 "expected isolation limit to be an ancestor of the given region"); 35 36 // Collect proper ancestors of `limit` upfront to avoid traversing the region 37 // tree for every value. 38 SmallPtrSet<Region *, 4> properAncestors; 39 for (auto *reg = limit.getParentRegion(); reg != nullptr; 40 reg = reg->getParentRegion()) { 41 properAncestors.insert(reg); 42 } 43 44 region.walk([callback, &properAncestors](Operation *op) { 45 for (OpOperand &operand : op->getOpOperands()) 46 // Callback on values defined in a proper ancestor of region. 47 if (properAncestors.count(operand.get().getParentRegion())) 48 callback(&operand); 49 }); 50 } 51 52 void mlir::visitUsedValuesDefinedAbove( 53 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { 54 for (Region ®ion : regions) 55 visitUsedValuesDefinedAbove(region, region, callback); 56 } 57 58 void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, 59 llvm::SetVector<Value> &values) { 60 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { 61 values.insert(operand->get()); 62 }); 63 } 64 65 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, 66 llvm::SetVector<Value> &values) { 67 for (Region ®ion : regions) 68 getUsedValuesDefinedAbove(region, region, values); 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Unreachable Block Elimination 73 //===----------------------------------------------------------------------===// 74 75 /// Erase the unreachable blocks within the provided regions. Returns success 76 /// if any blocks were erased, failure otherwise. 77 // TODO: We could likely merge this with the DCE algorithm below. 78 static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) { 79 // Set of blocks found to be reachable within a given region. 80 llvm::df_iterator_default_set<Block *, 16> reachable; 81 // If any blocks were found to be dead. 82 bool erasedDeadBlocks = false; 83 84 SmallVector<Region *, 1> worklist; 85 worklist.reserve(regions.size()); 86 for (Region ®ion : regions) 87 worklist.push_back(®ion); 88 while (!worklist.empty()) { 89 Region *region = worklist.pop_back_val(); 90 if (region->empty()) 91 continue; 92 93 // If this is a single block region, just collect the nested regions. 94 if (std::next(region->begin()) == region->end()) { 95 for (Operation &op : region->front()) 96 for (Region ®ion : op.getRegions()) 97 worklist.push_back(®ion); 98 continue; 99 } 100 101 // Mark all reachable blocks. 102 reachable.clear(); 103 for (Block *block : depth_first_ext(®ion->front(), reachable)) 104 (void)block /* Mark all reachable blocks */; 105 106 // Collect all of the dead blocks and push the live regions onto the 107 // worklist. 108 for (Block &block : llvm::make_early_inc_range(*region)) { 109 if (!reachable.count(&block)) { 110 block.dropAllDefinedValueUses(); 111 block.erase(); 112 erasedDeadBlocks = true; 113 continue; 114 } 115 116 // Walk any regions within this block. 117 for (Operation &op : block) 118 for (Region ®ion : op.getRegions()) 119 worklist.push_back(®ion); 120 } 121 } 122 123 return success(erasedDeadBlocks); 124 } 125 126 //===----------------------------------------------------------------------===// 127 // Dead Code Elimination 128 //===----------------------------------------------------------------------===// 129 130 namespace { 131 /// Data structure used to track which values have already been proved live. 132 /// 133 /// Because Operation's can have multiple results, this data structure tracks 134 /// liveness for both Value's and Operation's to avoid having to look through 135 /// all Operation results when analyzing a use. 136 /// 137 /// This data structure essentially tracks the dataflow lattice. 138 /// The set of values/ops proved live increases monotonically to a fixed-point. 139 class LiveMap { 140 public: 141 /// Value methods. 142 bool wasProvenLive(Value value) { 143 // TODO: For results that are removable, e.g. for region based control flow, 144 // we could allow for these values to be tracked independently. 145 if (OpResult result = value.dyn_cast<OpResult>()) 146 return wasProvenLive(result.getOwner()); 147 return wasProvenLive(value.cast<BlockArgument>()); 148 } 149 bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); } 150 void setProvedLive(Value value) { 151 // TODO: For results that are removable, e.g. for region based control flow, 152 // we could allow for these values to be tracked independently. 153 if (OpResult result = value.dyn_cast<OpResult>()) 154 return setProvedLive(result.getOwner()); 155 setProvedLive(value.cast<BlockArgument>()); 156 } 157 void setProvedLive(BlockArgument arg) { 158 changed |= liveValues.insert(arg).second; 159 } 160 161 /// Operation methods. 162 bool wasProvenLive(Operation *op) { return liveOps.count(op); } 163 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; } 164 165 /// Methods for tracking if we have reached a fixed-point. 166 void resetChanged() { changed = false; } 167 bool hasChanged() { return changed; } 168 169 private: 170 bool changed = false; 171 DenseSet<Value> liveValues; 172 DenseSet<Operation *> liveOps; 173 }; 174 } // namespace 175 176 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { 177 Operation *owner = use.getOwner(); 178 unsigned operandIndex = use.getOperandNumber(); 179 // This pass generally treats all uses of an op as live if the op itself is 180 // considered live. However, for successor operands to terminators we need a 181 // finer-grained notion where we deduce liveness for operands individually. 182 // The reason for this is easiest to think about in terms of a classical phi 183 // node based SSA IR, where each successor operand is really an operand to a 184 // *separate* phi node, rather than all operands to the branch itself as with 185 // the block argument representation that MLIR uses. 186 // 187 // And similarly, because each successor operand is really an operand to a phi 188 // node, rather than to the terminator op itself, a terminator op can't e.g. 189 // "print" the value of a successor operand. 190 if (owner->hasTrait<OpTrait::IsTerminator>()) { 191 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner)) 192 if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex)) 193 return !liveMap.wasProvenLive(*arg); 194 return false; 195 } 196 return false; 197 } 198 199 static void processValue(Value value, LiveMap &liveMap) { 200 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) { 201 if (isUseSpeciallyKnownDead(use, liveMap)) 202 return false; 203 return liveMap.wasProvenLive(use.getOwner()); 204 }); 205 if (provedLive) 206 liveMap.setProvedLive(value); 207 } 208 209 static void propagateLiveness(Region ®ion, LiveMap &liveMap); 210 211 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) { 212 // Terminators are always live. 213 liveMap.setProvedLive(op); 214 215 // Check to see if we can reason about the successor operands and mutate them. 216 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op); 217 if (!branchInterface) { 218 for (Block *successor : op->getSuccessors()) 219 for (BlockArgument arg : successor->getArguments()) 220 liveMap.setProvedLive(arg); 221 return; 222 } 223 224 // If we can't reason about the operands to a successor, conservatively mark 225 // all arguments as live. 226 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { 227 if (!branchInterface.getMutableSuccessorOperands(i)) 228 for (BlockArgument arg : op->getSuccessor(i)->getArguments()) 229 liveMap.setProvedLive(arg); 230 } 231 } 232 233 static void propagateLiveness(Operation *op, LiveMap &liveMap) { 234 // Recurse on any regions the op has. 235 for (Region ®ion : op->getRegions()) 236 propagateLiveness(region, liveMap); 237 238 // Process terminator operations. 239 if (op->hasTrait<OpTrait::IsTerminator>()) 240 return propagateTerminatorLiveness(op, liveMap); 241 242 // Don't reprocess live operations. 243 if (liveMap.wasProvenLive(op)) 244 return; 245 246 // Process the op itself. 247 if (!wouldOpBeTriviallyDead(op)) 248 return liveMap.setProvedLive(op); 249 250 // If the op isn't intrinsically alive, check it's results. 251 for (Value value : op->getResults()) 252 processValue(value, liveMap); 253 } 254 255 static void propagateLiveness(Region ®ion, LiveMap &liveMap) { 256 if (region.empty()) 257 return; 258 259 for (Block *block : llvm::post_order(®ion.front())) { 260 // We process block arguments after the ops in the block, to promote 261 // faster convergence to a fixed point (we try to visit uses before defs). 262 for (Operation &op : llvm::reverse(block->getOperations())) 263 propagateLiveness(&op, liveMap); 264 265 // We currently do not remove entry block arguments, so there is no need to 266 // track their liveness. 267 // TODO: We could track these and enable removing dead operands/arguments 268 // from region control flow operations. 269 if (block->isEntryBlock()) 270 continue; 271 272 for (Value value : block->getArguments()) { 273 if (!liveMap.wasProvenLive(value)) 274 processValue(value, liveMap); 275 } 276 } 277 } 278 279 static void eraseTerminatorSuccessorOperands(Operation *terminator, 280 LiveMap &liveMap) { 281 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator); 282 if (!branchOp) 283 return; 284 285 for (unsigned succI = 0, succE = terminator->getNumSuccessors(); 286 succI < succE; succI++) { 287 // Iterating successors in reverse is not strictly needed, since we 288 // aren't erasing any successors. But it is slightly more efficient 289 // since it will promote later operands of the terminator being erased 290 // first, reducing the quadratic-ness. 291 unsigned succ = succE - succI - 1; 292 Optional<MutableOperandRange> succOperands = 293 branchOp.getMutableSuccessorOperands(succ); 294 if (!succOperands) 295 continue; 296 Block *successor = terminator->getSuccessor(succ); 297 298 for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) { 299 // Iterating args in reverse is needed for correctness, to avoid 300 // shifting later args when earlier args are erased. 301 unsigned arg = argE - argI - 1; 302 if (!liveMap.wasProvenLive(successor->getArgument(arg))) 303 succOperands->erase(arg); 304 } 305 } 306 } 307 308 static LogicalResult deleteDeadness(MutableArrayRef<Region> regions, 309 LiveMap &liveMap) { 310 bool erasedAnything = false; 311 for (Region ®ion : regions) { 312 if (region.empty()) 313 continue; 314 315 // We do the deletion in an order that deletes all uses before deleting 316 // defs. 317 // MLIR's SSA structural invariants guarantee that except for block 318 // arguments, the use-def graph is acyclic, so this is possible with a 319 // single walk of ops and then a final pass to clean up block arguments. 320 // 321 // To do this, we visit ops in an order that visits domtree children 322 // before domtree parents. A CFG post-order (with reverse iteration with a 323 // block) satisfies that without needing an explicit domtree calculation. 324 for (Block *block : llvm::post_order(®ion.front())) { 325 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); 326 for (Operation &childOp : 327 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { 328 if (!liveMap.wasProvenLive(&childOp)) { 329 erasedAnything = true; 330 childOp.erase(); 331 } else { 332 erasedAnything |= 333 succeeded(deleteDeadness(childOp.getRegions(), liveMap)); 334 } 335 } 336 } 337 // Delete block arguments. 338 // The entry block has an unknown contract with their enclosing block, so 339 // skip it. 340 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) { 341 block.eraseArguments( 342 [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); }); 343 } 344 } 345 return success(erasedAnything); 346 } 347 348 // This function performs a simple dead code elimination algorithm over the 349 // given regions. 350 // 351 // The overall goal is to prove that Values are dead, which allows deleting ops 352 // and block arguments. 353 // 354 // This uses an optimistic algorithm that assumes everything is dead until 355 // proved otherwise, allowing it to delete recursively dead cycles. 356 // 357 // This is a simple fixed-point dataflow analysis algorithm on a lattice 358 // {Dead,Alive}. Because liveness flows backward, we generally try to 359 // iterate everything backward to speed up convergence to the fixed-point. This 360 // allows for being able to delete recursively dead cycles of the use-def graph, 361 // including block arguments. 362 // 363 // This function returns success if any operations or arguments were deleted, 364 // failure otherwise. 365 static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) { 366 LiveMap liveMap; 367 do { 368 liveMap.resetChanged(); 369 370 for (Region ®ion : regions) 371 propagateLiveness(region, liveMap); 372 } while (liveMap.hasChanged()); 373 374 return deleteDeadness(regions, liveMap); 375 } 376 377 //===----------------------------------------------------------------------===// 378 // Block Merging 379 //===----------------------------------------------------------------------===// 380 381 //===----------------------------------------------------------------------===// 382 // BlockEquivalenceData 383 384 namespace { 385 /// This class contains the information for comparing the equivalencies of two 386 /// blocks. Blocks are considered equivalent if they contain the same operations 387 /// in the same order. The only allowed divergence is for operands that come 388 /// from sources outside of the parent block, i.e. the uses of values produced 389 /// within the block must be equivalent. 390 /// e.g., 391 /// Equivalent: 392 /// ^bb1(%arg0: i32) 393 /// return %arg0, %foo : i32, i32 394 /// ^bb2(%arg1: i32) 395 /// return %arg1, %bar : i32, i32 396 /// Not Equivalent: 397 /// ^bb1(%arg0: i32) 398 /// return %foo, %arg0 : i32, i32 399 /// ^bb2(%arg1: i32) 400 /// return %arg1, %bar : i32, i32 401 struct BlockEquivalenceData { 402 BlockEquivalenceData(Block *block); 403 404 /// Return the order index for the given value that is within the block of 405 /// this data. 406 unsigned getOrderOf(Value value) const; 407 408 /// The block this data refers to. 409 Block *block; 410 /// A hash value for this block. 411 llvm::hash_code hash; 412 /// A map of result producing operations to their relative orders within this 413 /// block. The order of an operation is the number of defined values that are 414 /// produced within the block before this operation. 415 DenseMap<Operation *, unsigned> opOrderIndex; 416 }; 417 } // end anonymous namespace 418 419 BlockEquivalenceData::BlockEquivalenceData(Block *block) 420 : block(block), hash(0) { 421 unsigned orderIt = block->getNumArguments(); 422 for (Operation &op : *block) { 423 if (unsigned numResults = op.getNumResults()) { 424 opOrderIndex.try_emplace(&op, orderIt); 425 orderIt += numResults; 426 } 427 auto opHash = OperationEquivalence::computeHash( 428 &op, OperationEquivalence::Flags::IgnoreOperands); 429 hash = llvm::hash_combine(hash, opHash); 430 } 431 } 432 433 unsigned BlockEquivalenceData::getOrderOf(Value value) const { 434 assert(value.getParentBlock() == block && "expected value of this block"); 435 436 // Arguments use the argument number as the order index. 437 if (BlockArgument arg = value.dyn_cast<BlockArgument>()) 438 return arg.getArgNumber(); 439 440 // Otherwise, the result order is offset from the parent op's order. 441 OpResult result = value.cast<OpResult>(); 442 auto opOrderIt = opOrderIndex.find(result.getDefiningOp()); 443 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order"); 444 return opOrderIt->second + result.getResultNumber(); 445 } 446 447 //===----------------------------------------------------------------------===// 448 // BlockMergeCluster 449 450 namespace { 451 /// This class represents a cluster of blocks to be merged together. 452 class BlockMergeCluster { 453 public: 454 BlockMergeCluster(BlockEquivalenceData &&leaderData) 455 : leaderData(std::move(leaderData)) {} 456 457 /// Attempt to add the given block to this cluster. Returns success if the 458 /// block was merged, failure otherwise. 459 LogicalResult addToCluster(BlockEquivalenceData &blockData); 460 461 /// Try to merge all of the blocks within this cluster into the leader block. 462 LogicalResult merge(); 463 464 private: 465 /// The equivalence data for the leader of the cluster. 466 BlockEquivalenceData leaderData; 467 468 /// The set of blocks that can be merged into the leader. 469 llvm::SmallSetVector<Block *, 1> blocksToMerge; 470 471 /// A set of operand+index pairs that correspond to operands that need to be 472 /// replaced by arguments when the cluster gets merged. 473 std::set<std::pair<int, int>> operandsToMerge; 474 }; 475 } // end anonymous namespace 476 477 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) { 478 if (leaderData.hash != blockData.hash) 479 return failure(); 480 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block; 481 if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes()) 482 return failure(); 483 484 // A set of operands that mismatch between the leader and the new block. 485 SmallVector<std::pair<int, int>, 8> mismatchedOperands; 486 auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end(); 487 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end(); 488 for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) { 489 // Check that the operations are equivalent. 490 if (!OperationEquivalence::isEquivalentTo( 491 &*lhsIt, &*rhsIt, OperationEquivalence::Flags::IgnoreOperands)) 492 return failure(); 493 494 // Compare the operands of the two operations. If the operand is within 495 // the block, it must refer to the same operation. 496 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands(); 497 for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) { 498 Value lhsOperand = lhsOperands[operand]; 499 Value rhsOperand = rhsOperands[operand]; 500 if (lhsOperand == rhsOperand) 501 continue; 502 // Check that the types of the operands match. 503 if (lhsOperand.getType() != rhsOperand.getType()) 504 return failure(); 505 506 // Check that these uses are both external, or both internal. 507 bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock; 508 bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock; 509 if (lhsIsInBlock != rhsIsInBlock) 510 return failure(); 511 // Let the operands differ if they are defined in a different block. These 512 // will become new arguments if the blocks get merged. 513 if (!lhsIsInBlock) { 514 mismatchedOperands.emplace_back(opI, operand); 515 continue; 516 } 517 518 // Otherwise, these operands must have the same logical order within the 519 // parent block. 520 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand)) 521 return failure(); 522 } 523 524 // If the lhs or rhs has external uses, the blocks cannot be merged as the 525 // merged version of this operation will not be either the lhs or rhs 526 // alone (thus semantically incorrect), but some mix dependending on which 527 // block preceeded this. 528 // TODO allow merging of operations when one block does not dominate the 529 // other 530 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) || 531 lhsIt->isUsedOutsideOfBlock(leaderBlock)) { 532 return failure(); 533 } 534 } 535 // Make sure that the block sizes are equivalent. 536 if (lhsIt != lhsE || rhsIt != rhsE) 537 return failure(); 538 539 // If we get here, the blocks are equivalent and can be merged. 540 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end()); 541 blocksToMerge.insert(blockData.block); 542 return success(); 543 } 544 545 /// Returns true if the predecessor terminators of the given block can not have 546 /// their operands updated. 547 static bool ableToUpdatePredOperands(Block *block) { 548 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { 549 auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator()); 550 if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex())) 551 return false; 552 } 553 return true; 554 } 555 556 LogicalResult BlockMergeCluster::merge() { 557 // Don't consider clusters that don't have blocks to merge. 558 if (blocksToMerge.empty()) 559 return failure(); 560 561 Block *leaderBlock = leaderData.block; 562 if (!operandsToMerge.empty()) { 563 // If the cluster has operands to merge, verify that the predecessor 564 // terminators of each of the blocks can have their successor operands 565 // updated. 566 // TODO: We could try and sub-partition this cluster if only some blocks 567 // cause the mismatch. 568 if (!ableToUpdatePredOperands(leaderBlock) || 569 !llvm::all_of(blocksToMerge, ableToUpdatePredOperands)) 570 return failure(); 571 572 // Collect the iterators for each of the blocks to merge. We will walk all 573 // of the iterators at once to avoid operand index invalidation. 574 SmallVector<Block::iterator, 2> blockIterators; 575 blockIterators.reserve(blocksToMerge.size() + 1); 576 blockIterators.push_back(leaderBlock->begin()); 577 for (Block *mergeBlock : blocksToMerge) 578 blockIterators.push_back(mergeBlock->begin()); 579 580 // Update each of the predecessor terminators with the new arguments. 581 SmallVector<SmallVector<Value, 8>, 2> newArguments( 582 1 + blocksToMerge.size(), 583 SmallVector<Value, 8>(operandsToMerge.size())); 584 unsigned curOpIndex = 0; 585 for (auto it : llvm::enumerate(operandsToMerge)) { 586 unsigned nextOpOffset = it.value().first - curOpIndex; 587 curOpIndex = it.value().first; 588 589 // Process the operand for each of the block iterators. 590 for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) { 591 Block::iterator &blockIter = blockIterators[i]; 592 std::advance(blockIter, nextOpOffset); 593 auto &operand = blockIter->getOpOperand(it.value().second); 594 newArguments[i][it.index()] = operand.get(); 595 596 // Update the operand and insert an argument if this is the leader. 597 if (i == 0) 598 operand.set(leaderBlock->addArgument(operand.get().getType())); 599 } 600 } 601 // Update the predecessors for each of the blocks. 602 auto updatePredecessors = [&](Block *block, unsigned clusterIndex) { 603 for (auto predIt = block->pred_begin(), predE = block->pred_end(); 604 predIt != predE; ++predIt) { 605 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator()); 606 unsigned succIndex = predIt.getSuccessorIndex(); 607 branch.getMutableSuccessorOperands(succIndex)->append( 608 newArguments[clusterIndex]); 609 } 610 }; 611 updatePredecessors(leaderBlock, /*clusterIndex=*/0); 612 for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i) 613 updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1); 614 } 615 616 // Replace all uses of the merged blocks with the leader and erase them. 617 for (Block *block : blocksToMerge) { 618 block->replaceAllUsesWith(leaderBlock); 619 block->erase(); 620 } 621 return success(); 622 } 623 624 /// Identify identical blocks within the given region and merge them, inserting 625 /// new block arguments as necessary. Returns success if any blocks were merged, 626 /// failure otherwise. 627 static LogicalResult mergeIdenticalBlocks(Region ®ion) { 628 if (region.empty() || llvm::hasSingleElement(region)) 629 return failure(); 630 631 // Identify sets of blocks, other than the entry block, that branch to the 632 // same successors. We will use these groups to create clusters of equivalent 633 // blocks. 634 DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors; 635 for (Block &block : llvm::drop_begin(region, 1)) 636 matchingSuccessors[block.getSuccessors()].push_back(&block); 637 638 bool mergedAnyBlocks = false; 639 for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) { 640 if (blocks.size() == 1) 641 continue; 642 643 SmallVector<BlockMergeCluster, 1> clusters; 644 for (Block *block : blocks) { 645 BlockEquivalenceData data(block); 646 647 // Don't allow merging if this block has any regions. 648 // TODO: Add support for regions if necessary. 649 bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) { 650 return llvm::any_of(op.getRegions(), 651 [](Region ®ion) { return !region.empty(); }); 652 }); 653 if (hasNonEmptyRegion) 654 continue; 655 656 // Try to add this block to an existing cluster. 657 bool addedToCluster = false; 658 for (auto &cluster : clusters) 659 if ((addedToCluster = succeeded(cluster.addToCluster(data)))) 660 break; 661 if (!addedToCluster) 662 clusters.emplace_back(std::move(data)); 663 } 664 for (auto &cluster : clusters) 665 mergedAnyBlocks |= succeeded(cluster.merge()); 666 } 667 668 return success(mergedAnyBlocks); 669 } 670 671 /// Identify identical blocks within the given regions and merge them, inserting 672 /// new block arguments as necessary. 673 static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) { 674 llvm::SmallSetVector<Region *, 1> worklist; 675 for (auto ®ion : regions) 676 worklist.insert(®ion); 677 bool anyChanged = false; 678 while (!worklist.empty()) { 679 Region *region = worklist.pop_back_val(); 680 if (succeeded(mergeIdenticalBlocks(*region))) { 681 worklist.insert(region); 682 anyChanged = true; 683 } 684 685 // Add any nested regions to the worklist. 686 for (Block &block : *region) 687 for (auto &op : block) 688 for (auto &nestedRegion : op.getRegions()) 689 worklist.insert(&nestedRegion); 690 } 691 692 return success(anyChanged); 693 } 694 695 //===----------------------------------------------------------------------===// 696 // Region Simplification 697 //===----------------------------------------------------------------------===// 698 699 /// Run a set of structural simplifications over the given regions. This 700 /// includes transformations like unreachable block elimination, dead argument 701 /// elimination, as well as some other DCE. This function returns success if any 702 /// of the regions were simplified, failure otherwise. 703 LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) { 704 bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions)); 705 bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions)); 706 bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions)); 707 return success(eliminatedBlocks || eliminatedOpsOrArgs || 708 mergedIdenticalBlocks); 709 } 710