1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/RegionUtils.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/RegionGraphTraits.h"
13 #include "mlir/IR/Value.h"
14 #include "mlir/Interfaces/ControlFlowInterfaces.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
16 
17 #include "llvm/ADT/DepthFirstIterator.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/ADT/SmallSet.h"
20 
21 using namespace mlir;
22 
23 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
24                                       Region &region) {
25   for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
26     if (region.isAncestor(use.getOwner()->getParentRegion()))
27       use.set(replacement);
28   }
29 }
30 
31 void mlir::visitUsedValuesDefinedAbove(
32     Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
33   assert(limit.isAncestor(&region) &&
34          "expected isolation limit to be an ancestor of the given region");
35 
36   // Collect proper ancestors of `limit` upfront to avoid traversing the region
37   // tree for every value.
38   SmallPtrSet<Region *, 4> properAncestors;
39   for (auto *reg = limit.getParentRegion(); reg != nullptr;
40        reg = reg->getParentRegion()) {
41     properAncestors.insert(reg);
42   }
43 
44   region.walk([callback, &properAncestors](Operation *op) {
45     for (OpOperand &operand : op->getOpOperands())
46       // Callback on values defined in a proper ancestor of region.
47       if (properAncestors.count(operand.get().getParentRegion()))
48         callback(&operand);
49   });
50 }
51 
52 void mlir::visitUsedValuesDefinedAbove(
53     MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
54   for (Region &region : regions)
55     visitUsedValuesDefinedAbove(region, region, callback);
56 }
57 
58 void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
59                                      llvm::SetVector<Value> &values) {
60   visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
61     values.insert(operand->get());
62   });
63 }
64 
65 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
66                                      llvm::SetVector<Value> &values) {
67   for (Region &region : regions)
68     getUsedValuesDefinedAbove(region, region, values);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Unreachable Block Elimination
73 //===----------------------------------------------------------------------===//
74 
75 /// Erase the unreachable blocks within the provided regions. Returns success
76 /// if any blocks were erased, failure otherwise.
77 // TODO: We could likely merge this with the DCE algorithm below.
78 static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
79   // Set of blocks found to be reachable within a given region.
80   llvm::df_iterator_default_set<Block *, 16> reachable;
81   // If any blocks were found to be dead.
82   bool erasedDeadBlocks = false;
83 
84   SmallVector<Region *, 1> worklist;
85   worklist.reserve(regions.size());
86   for (Region &region : regions)
87     worklist.push_back(&region);
88   while (!worklist.empty()) {
89     Region *region = worklist.pop_back_val();
90     if (region->empty())
91       continue;
92 
93     // If this is a single block region, just collect the nested regions.
94     if (std::next(region->begin()) == region->end()) {
95       for (Operation &op : region->front())
96         for (Region &region : op.getRegions())
97           worklist.push_back(&region);
98       continue;
99     }
100 
101     // Mark all reachable blocks.
102     reachable.clear();
103     for (Block *block : depth_first_ext(&region->front(), reachable))
104       (void)block /* Mark all reachable blocks */;
105 
106     // Collect all of the dead blocks and push the live regions onto the
107     // worklist.
108     for (Block &block : llvm::make_early_inc_range(*region)) {
109       if (!reachable.count(&block)) {
110         block.dropAllDefinedValueUses();
111         block.erase();
112         erasedDeadBlocks = true;
113         continue;
114       }
115 
116       // Walk any regions within this block.
117       for (Operation &op : block)
118         for (Region &region : op.getRegions())
119           worklist.push_back(&region);
120     }
121   }
122 
123   return success(erasedDeadBlocks);
124 }
125 
126 //===----------------------------------------------------------------------===//
127 // Dead Code Elimination
128 //===----------------------------------------------------------------------===//
129 
namespace {
/// Data structure used to track which values have already been proved live.
///
/// Because Operation's can have multiple results, this data structure tracks
/// liveness for both Value's and Operation's to avoid having to look through
/// all Operation results when analyzing a use.
///
/// This data structure essentially tracks the dataflow lattice.
/// The set of values/ops proved live increases monotonically to a fixed-point.
class LiveMap {
public:
  /// Value methods.
  /// Returns true if the value was proven live. An OpResult is considered
  /// live exactly when its owning operation is.
  bool wasProvenLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control flow,
    // we could allow for these values to be tracked independently.
    if (OpResult result = value.dyn_cast<OpResult>())
      return wasProvenLive(result.getOwner());
    return wasProvenLive(value.cast<BlockArgument>());
  }
  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
  /// Marks the value live, delegating OpResults to their owning operation.
  void setProvedLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control flow,
    // we could allow for these values to be tracked independently.
    if (OpResult result = value.dyn_cast<OpResult>())
      return setProvedLive(result.getOwner());
    setProvedLive(value.cast<BlockArgument>());
  }
  /// Marks the argument live, recording whether the lattice changed.
  void setProvedLive(BlockArgument arg) {
    changed |= liveValues.insert(arg).second;
  }

  /// Operation methods.
  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }

  /// Methods for tracking if we have reached a fixed-point.
  void resetChanged() { changed = false; }
  bool hasChanged() { return changed; }

private:
  /// Whether anything new was proven live since the last resetChanged().
  bool changed = false;
  DenseSet<Value> liveValues;
  DenseSet<Operation *> liveOps;
};
} // namespace
175 
176 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
177   Operation *owner = use.getOwner();
178   unsigned operandIndex = use.getOperandNumber();
179   // This pass generally treats all uses of an op as live if the op itself is
180   // considered live. However, for successor operands to terminators we need a
181   // finer-grained notion where we deduce liveness for operands individually.
182   // The reason for this is easiest to think about in terms of a classical phi
183   // node based SSA IR, where each successor operand is really an operand to a
184   // *separate* phi node, rather than all operands to the branch itself as with
185   // the block argument representation that MLIR uses.
186   //
187   // And similarly, because each successor operand is really an operand to a phi
188   // node, rather than to the terminator op itself, a terminator op can't e.g.
189   // "print" the value of a successor operand.
190   if (owner->hasTrait<OpTrait::IsTerminator>()) {
191     if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
192       if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
193         return !liveMap.wasProvenLive(*arg);
194     return false;
195   }
196   return false;
197 }
198 
199 static void processValue(Value value, LiveMap &liveMap) {
200   bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
201     if (isUseSpeciallyKnownDead(use, liveMap))
202       return false;
203     return liveMap.wasProvenLive(use.getOwner());
204   });
205   if (provedLive)
206     liveMap.setProvedLive(value);
207 }
208 
209 static void propagateLiveness(Region &region, LiveMap &liveMap);
210 
211 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
212   // Terminators are always live.
213   liveMap.setProvedLive(op);
214 
215   // Check to see if we can reason about the successor operands and mutate them.
216   BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
217   if (!branchInterface) {
218     for (Block *successor : op->getSuccessors())
219       for (BlockArgument arg : successor->getArguments())
220         liveMap.setProvedLive(arg);
221     return;
222   }
223 
224   // If we can't reason about the operands to a successor, conservatively mark
225   // all arguments as live.
226   for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
227     if (!branchInterface.getMutableSuccessorOperands(i))
228       for (BlockArgument arg : op->getSuccessor(i)->getArguments())
229         liveMap.setProvedLive(arg);
230   }
231 }
232 
233 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
234   // Recurse on any regions the op has.
235   for (Region &region : op->getRegions())
236     propagateLiveness(region, liveMap);
237 
238   // Process terminator operations.
239   if (op->hasTrait<OpTrait::IsTerminator>())
240     return propagateTerminatorLiveness(op, liveMap);
241 
242   // Don't reprocess live operations.
243   if (liveMap.wasProvenLive(op))
244     return;
245 
246   // Process the op itself.
247   if (!wouldOpBeTriviallyDead(op))
248     return liveMap.setProvedLive(op);
249 
250   // If the op isn't intrinsically alive, check it's results.
251   for (Value value : op->getResults())
252     processValue(value, liveMap);
253 }
254 
255 static void propagateLiveness(Region &region, LiveMap &liveMap) {
256   if (region.empty())
257     return;
258 
259   for (Block *block : llvm::post_order(&region.front())) {
260     // We process block arguments after the ops in the block, to promote
261     // faster convergence to a fixed point (we try to visit uses before defs).
262     for (Operation &op : llvm::reverse(block->getOperations()))
263       propagateLiveness(&op, liveMap);
264 
265     // We currently do not remove entry block arguments, so there is no need to
266     // track their liveness.
267     // TODO: We could track these and enable removing dead operands/arguments
268     // from region control flow operations.
269     if (block->isEntryBlock())
270       continue;
271 
272     for (Value value : block->getArguments()) {
273       if (!liveMap.wasProvenLive(value))
274         processValue(value, liveMap);
275     }
276   }
277 }
278 
279 static void eraseTerminatorSuccessorOperands(Operation *terminator,
280                                              LiveMap &liveMap) {
281   BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
282   if (!branchOp)
283     return;
284 
285   for (unsigned succI = 0, succE = terminator->getNumSuccessors();
286        succI < succE; succI++) {
287     // Iterating successors in reverse is not strictly needed, since we
288     // aren't erasing any successors. But it is slightly more efficient
289     // since it will promote later operands of the terminator being erased
290     // first, reducing the quadratic-ness.
291     unsigned succ = succE - succI - 1;
292     Optional<MutableOperandRange> succOperands =
293         branchOp.getMutableSuccessorOperands(succ);
294     if (!succOperands)
295       continue;
296     Block *successor = terminator->getSuccessor(succ);
297 
298     for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) {
299       // Iterating args in reverse is needed for correctness, to avoid
300       // shifting later args when earlier args are erased.
301       unsigned arg = argE - argI - 1;
302       if (!liveMap.wasProvenLive(successor->getArgument(arg)))
303         succOperands->erase(arg);
304     }
305   }
306 }
307 
/// Erases everything in `regions` that was not proven live by `liveMap`.
/// Returns success if anything was erased, failure otherwise.
static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
                                    LiveMap &liveMap) {
  bool erasedAnything = false;
  for (Region &region : regions) {
    if (region.empty())
      continue;

    // Delete every operation that is not live. Graph regions may have cycles
    // in the use-def graph, so we must explicitly dropAllUses() from each
    // operation as we erase it. Visiting the operations in post-order
    // guarantees that in SSA CFG regions value uses are removed before defs,
    // which makes dropAllUses() a no-op.
    for (Block *block : llvm::post_order(&region.front())) {
      // Drop dead successor operands from the terminator first, so the
      // corresponding dead block arguments lose those uses.
      eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
      for (Operation &childOp :
           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
        if (!liveMap.wasProvenLive(&childOp)) {
          erasedAnything = true;
          childOp.dropAllUses();
          childOp.erase();
        } else {
          // Live ops are kept, but dead code within their regions can still
          // be deleted recursively.
          erasedAnything |=
              succeeded(deleteDeadness(childOp.getRegions(), liveMap));
        }
      }
    }
    // Delete dead block arguments of non-entry blocks. Entry block arguments
    // have an unknown contract with the enclosing operation, so skip them.
    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
      block.eraseArguments(
          [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
    }
  }
  return success(erasedAnything);
}
344 
345 // This function performs a simple dead code elimination algorithm over the
346 // given regions.
347 //
348 // The overall goal is to prove that Values are dead, which allows deleting ops
349 // and block arguments.
350 //
351 // This uses an optimistic algorithm that assumes everything is dead until
352 // proved otherwise, allowing it to delete recursively dead cycles.
353 //
354 // This is a simple fixed-point dataflow analysis algorithm on a lattice
355 // {Dead,Alive}. Because liveness flows backward, we generally try to
356 // iterate everything backward to speed up convergence to the fixed-point. This
357 // allows for being able to delete recursively dead cycles of the use-def graph,
358 // including block arguments.
359 //
360 // This function returns success if any operations or arguments were deleted,
361 // failure otherwise.
362 static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
363   LiveMap liveMap;
364   do {
365     liveMap.resetChanged();
366 
367     for (Region &region : regions)
368       propagateLiveness(region, liveMap);
369   } while (liveMap.hasChanged());
370 
371   return deleteDeadness(regions, liveMap);
372 }
373 
374 //===----------------------------------------------------------------------===//
375 // Block Merging
376 //===----------------------------------------------------------------------===//
377 
378 //===----------------------------------------------------------------------===//
379 // BlockEquivalenceData
380 
namespace {
/// This class contains the information for comparing the equivalencies of two
/// blocks. Blocks are considered equivalent if they contain the same operations
/// in the same order. The only allowed divergence is for operands that come
/// from sources outside of the parent block, i.e. the uses of values produced
/// within the block must be equivalent.
///   e.g.,
/// Equivalent:
///  ^bb1(%arg0: i32)
///    return %arg0, %foo : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
/// Not Equivalent:
///  ^bb1(%arg0: i32)
///    return %foo, %arg0 : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
struct BlockEquivalenceData {
  BlockEquivalenceData(Block *block);

  /// Return the order index for the given value that is within the block of
  /// this data.
  unsigned getOrderOf(Value value) const;

  /// The block this data refers to.
  Block *block;
  /// A hash value for this block, combined from the per-operation hashes with
  /// operands ignored (see the constructor).
  llvm::hash_code hash;
  /// A map of result producing operations to their relative orders within this
  /// block. The order of an operation is the number of defined values that are
  /// produced within the block before this operation.
  DenseMap<Operation *, unsigned> opOrderIndex;
};
} // end anonymous namespace
415 
416 BlockEquivalenceData::BlockEquivalenceData(Block *block)
417     : block(block), hash(0) {
418   unsigned orderIt = block->getNumArguments();
419   for (Operation &op : *block) {
420     if (unsigned numResults = op.getNumResults()) {
421       opOrderIndex.try_emplace(&op, orderIt);
422       orderIt += numResults;
423     }
424     auto opHash = OperationEquivalence::computeHash(
425         &op, OperationEquivalence::Flags::IgnoreOperands);
426     hash = llvm::hash_combine(hash, opHash);
427   }
428 }
429 
430 unsigned BlockEquivalenceData::getOrderOf(Value value) const {
431   assert(value.getParentBlock() == block && "expected value of this block");
432 
433   // Arguments use the argument number as the order index.
434   if (BlockArgument arg = value.dyn_cast<BlockArgument>())
435     return arg.getArgNumber();
436 
437   // Otherwise, the result order is offset from the parent op's order.
438   OpResult result = value.cast<OpResult>();
439   auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
440   assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
441   return opOrderIt->second + result.getResultNumber();
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // BlockMergeCluster
446 
namespace {
/// This class represents a cluster of blocks to be merged together.
class BlockMergeCluster {
public:
  BlockMergeCluster(BlockEquivalenceData &&leaderData)
      : leaderData(std::move(leaderData)) {}

  /// Attempt to add the given block to this cluster. Returns success if the
  /// block was merged, failure otherwise.
  LogicalResult addToCluster(BlockEquivalenceData &blockData);

  /// Try to merge all of the blocks within this cluster into the leader block.
  LogicalResult merge();

private:
  /// The equivalence data for the leader of the cluster.
  BlockEquivalenceData leaderData;

  /// The set of blocks that can be merged into the leader.
  llvm::SmallSetVector<Block *, 1> blocksToMerge;

  /// A set of operand+index pairs that correspond to operands that need to be
  /// replaced by arguments when the cluster gets merged. Each pair holds
  /// (operation index within the block, operand index within that operation);
  /// std::set keeps the pairs sorted, which merge() relies on when advancing
  /// the per-block iterators monotonically.
  std::set<std::pair<int, int>> operandsToMerge;
};
} // end anonymous namespace
473 
/// Attempt to add `blockData`'s block to this cluster. On success the block is
/// recorded in `blocksToMerge` and its operands that diverge from the leader
/// are recorded in `operandsToMerge`; on failure the cluster is unchanged.
LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
  // Quick rejection: equivalent blocks must hash identically and carry the
  // same argument types.
  if (leaderData.hash != blockData.hash)
    return failure();
  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
    return failure();

  // A set of operands that mismatch between the leader and the new block.
  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
    // Check that the operations are equivalent.
    if (!OperationEquivalence::isEquivalentTo(
            &*lhsIt, &*rhsIt, OperationEquivalence::Flags::IgnoreOperands))
      return failure();

    // Compare the operands of the two operations. If the operand is within
    // the block, it must refer to the same operation.
    auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
    for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
      Value lhsOperand = lhsOperands[operand];
      Value rhsOperand = rhsOperands[operand];
      if (lhsOperand == rhsOperand)
        continue;
      // Check that the types of the operands match.
      if (lhsOperand.getType() != rhsOperand.getType())
        return failure();

      // Check that these uses are both external, or both internal.
      bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
      bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
      if (lhsIsInBlock != rhsIsInBlock)
        return failure();
      // Let the operands differ if they are defined in a different block. These
      // will become new arguments if the blocks get merged.
      if (!lhsIsInBlock) {
        mismatchedOperands.emplace_back(opI, operand);
        continue;
      }

      // Otherwise, these operands must have the same logical order within the
      // parent block.
      if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
        return failure();
    }

    // If the lhs or rhs has external uses, the blocks cannot be merged as the
    // merged version of this operation will not be either the lhs or rhs
    // alone (thus semantically incorrect), but some mix depending on which
    // block preceded this.
    // TODO: allow merging of operations when one block does not dominate the
    // other.
    if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
        lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
      return failure();
    }
  }
  // Make sure that the block sizes are equivalent.
  if (lhsIt != lhsE || rhsIt != rhsE)
    return failure();

  // If we get here, the blocks are equivalent and can be merged.
  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
  blocksToMerge.insert(blockData.block);
  return success();
}
541 
/// Returns true if the predecessor terminators of the given block CAN have
/// their operands updated, i.e. every predecessor terminator implements
/// BranchOpInterface and exposes a mutable successor operand range for the
/// edge into `block`. (Note: returns true for a block with no predecessors.)
static bool ableToUpdatePredOperands(Block *block) {
  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
    auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
    if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex()))
      return false;
  }
  return true;
}
552 
/// Try to merge all of the blocks within this cluster into the leader block,
/// adding block arguments for any mismatched operands and forwarding the
/// original values through the predecessor terminators.
LogicalResult BlockMergeCluster::merge() {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    // newArguments[0] holds the forwarded values for the leader block; entry
    // i+1 holds them for the i-th block in `blocksToMerge`.
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    // `operandsToMerge` is sorted by (op index, operand index), so each block
    // iterator advances monotonically by the delta to the next op index.
    for (auto it : llvm::enumerate(operandsToMerge)) {
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0)
          operand.set(leaderBlock->addArgument(operand.get().getType()));
      }
    }
    // Update the predecessors for each of the blocks: append the forwarded
    // values to the successor operands of every predecessor terminator.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getMutableSuccessorOperands(succIndex)->append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    block->erase();
  }
  return success();
}
620 
621 /// Identify identical blocks within the given region and merge them, inserting
622 /// new block arguments as necessary. Returns success if any blocks were merged,
623 /// failure otherwise.
624 static LogicalResult mergeIdenticalBlocks(Region &region) {
625   if (region.empty() || llvm::hasSingleElement(region))
626     return failure();
627 
628   // Identify sets of blocks, other than the entry block, that branch to the
629   // same successors. We will use these groups to create clusters of equivalent
630   // blocks.
631   DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
632   for (Block &block : llvm::drop_begin(region, 1))
633     matchingSuccessors[block.getSuccessors()].push_back(&block);
634 
635   bool mergedAnyBlocks = false;
636   for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
637     if (blocks.size() == 1)
638       continue;
639 
640     SmallVector<BlockMergeCluster, 1> clusters;
641     for (Block *block : blocks) {
642       BlockEquivalenceData data(block);
643 
644       // Don't allow merging if this block has any regions.
645       // TODO: Add support for regions if necessary.
646       bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
647         return llvm::any_of(op.getRegions(),
648                             [](Region &region) { return !region.empty(); });
649       });
650       if (hasNonEmptyRegion)
651         continue;
652 
653       // Try to add this block to an existing cluster.
654       bool addedToCluster = false;
655       for (auto &cluster : clusters)
656         if ((addedToCluster = succeeded(cluster.addToCluster(data))))
657           break;
658       if (!addedToCluster)
659         clusters.emplace_back(std::move(data));
660     }
661     for (auto &cluster : clusters)
662       mergedAnyBlocks |= succeeded(cluster.merge());
663   }
664 
665   return success(mergedAnyBlocks);
666 }
667 
668 /// Identify identical blocks within the given regions and merge them, inserting
669 /// new block arguments as necessary.
670 static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) {
671   llvm::SmallSetVector<Region *, 1> worklist;
672   for (auto &region : regions)
673     worklist.insert(&region);
674   bool anyChanged = false;
675   while (!worklist.empty()) {
676     Region *region = worklist.pop_back_val();
677     if (succeeded(mergeIdenticalBlocks(*region))) {
678       worklist.insert(region);
679       anyChanged = true;
680     }
681 
682     // Add any nested regions to the worklist.
683     for (Block &block : *region)
684       for (auto &op : block)
685         for (auto &nestedRegion : op.getRegions())
686           worklist.insert(&nestedRegion);
687   }
688 
689   return success(anyChanged);
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // Region Simplification
694 //===----------------------------------------------------------------------===//
695 
696 /// Run a set of structural simplifications over the given regions. This
697 /// includes transformations like unreachable block elimination, dead argument
698 /// elimination, as well as some other DCE. This function returns success if any
699 /// of the regions were simplified, failure otherwise.
700 LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
701   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions));
702   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions));
703   bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions));
704   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
705                  mergedIdenticalBlocks);
706 }
707