1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Transforms/RegionUtils.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/IR/RegionGraphTraits.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Interfaces/ControlFlowInterfaces.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17
18 #include "llvm/ADT/DepthFirstIterator.h"
19 #include "llvm/ADT/PostOrderIterator.h"
20 #include "llvm/ADT/SmallSet.h"
21
22 using namespace mlir;
23
replaceAllUsesInRegionWith(Value orig,Value replacement,Region & region)24 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
25 Region ®ion) {
26 for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
27 if (region.isAncestor(use.getOwner()->getParentRegion()))
28 use.set(replacement);
29 }
30 }
31
visitUsedValuesDefinedAbove(Region & region,Region & limit,function_ref<void (OpOperand *)> callback)32 void mlir::visitUsedValuesDefinedAbove(
33 Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) {
34 assert(limit.isAncestor(®ion) &&
35 "expected isolation limit to be an ancestor of the given region");
36
37 // Collect proper ancestors of `limit` upfront to avoid traversing the region
38 // tree for every value.
39 SmallPtrSet<Region *, 4> properAncestors;
40 for (auto *reg = limit.getParentRegion(); reg != nullptr;
41 reg = reg->getParentRegion()) {
42 properAncestors.insert(reg);
43 }
44
45 region.walk([callback, &properAncestors](Operation *op) {
46 for (OpOperand &operand : op->getOpOperands())
47 // Callback on values defined in a proper ancestor of region.
48 if (properAncestors.count(operand.get().getParentRegion()))
49 callback(&operand);
50 });
51 }
52
visitUsedValuesDefinedAbove(MutableArrayRef<Region> regions,function_ref<void (OpOperand *)> callback)53 void mlir::visitUsedValuesDefinedAbove(
54 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
55 for (Region ®ion : regions)
56 visitUsedValuesDefinedAbove(region, region, callback);
57 }
58
getUsedValuesDefinedAbove(Region & region,Region & limit,SetVector<Value> & values)59 void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit,
60 SetVector<Value> &values) {
61 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
62 values.insert(operand->get());
63 });
64 }
65
getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,SetVector<Value> & values)66 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
67 SetVector<Value> &values) {
68 for (Region ®ion : regions)
69 getUsedValuesDefinedAbove(region, region, values);
70 }
71
72 //===----------------------------------------------------------------------===//
73 // Unreachable Block Elimination
74 //===----------------------------------------------------------------------===//
75
76 /// Erase the unreachable blocks within the provided regions. Returns success
77 /// if any blocks were erased, failure otherwise.
78 // TODO: We could likely merge this with the DCE algorithm below.
eraseUnreachableBlocks(RewriterBase & rewriter,MutableArrayRef<Region> regions)79 LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
80 MutableArrayRef<Region> regions) {
81 // Set of blocks found to be reachable within a given region.
82 llvm::df_iterator_default_set<Block *, 16> reachable;
83 // If any blocks were found to be dead.
84 bool erasedDeadBlocks = false;
85
86 SmallVector<Region *, 1> worklist;
87 worklist.reserve(regions.size());
88 for (Region ®ion : regions)
89 worklist.push_back(®ion);
90 while (!worklist.empty()) {
91 Region *region = worklist.pop_back_val();
92 if (region->empty())
93 continue;
94
95 // If this is a single block region, just collect the nested regions.
96 if (std::next(region->begin()) == region->end()) {
97 for (Operation &op : region->front())
98 for (Region ®ion : op.getRegions())
99 worklist.push_back(®ion);
100 continue;
101 }
102
103 // Mark all reachable blocks.
104 reachable.clear();
105 for (Block *block : depth_first_ext(®ion->front(), reachable))
106 (void)block /* Mark all reachable blocks */;
107
108 // Collect all of the dead blocks and push the live regions onto the
109 // worklist.
110 for (Block &block : llvm::make_early_inc_range(*region)) {
111 if (!reachable.count(&block)) {
112 block.dropAllDefinedValueUses();
113 rewriter.eraseBlock(&block);
114 erasedDeadBlocks = true;
115 continue;
116 }
117
118 // Walk any regions within this block.
119 for (Operation &op : block)
120 for (Region ®ion : op.getRegions())
121 worklist.push_back(®ion);
122 }
123 }
124
125 return success(erasedDeadBlocks);
126 }
127
128 //===----------------------------------------------------------------------===//
129 // Dead Code Elimination
130 //===----------------------------------------------------------------------===//
131
132 namespace {
133 /// Data structure used to track which values have already been proved live.
134 ///
135 /// Because Operation's can have multiple results, this data structure tracks
136 /// liveness for both Value's and Operation's to avoid having to look through
137 /// all Operation results when analyzing a use.
138 ///
139 /// This data structure essentially tracks the dataflow lattice.
140 /// The set of values/ops proved live increases monotonically to a fixed-point.
141 class LiveMap {
142 public:
143 /// Value methods.
wasProvenLive(Value value)144 bool wasProvenLive(Value value) {
145 // TODO: For results that are removable, e.g. for region based control flow,
146 // we could allow for these values to be tracked independently.
147 if (OpResult result = value.dyn_cast<OpResult>())
148 return wasProvenLive(result.getOwner());
149 return wasProvenLive(value.cast<BlockArgument>());
150 }
wasProvenLive(BlockArgument arg)151 bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
setProvedLive(Value value)152 void setProvedLive(Value value) {
153 // TODO: For results that are removable, e.g. for region based control flow,
154 // we could allow for these values to be tracked independently.
155 if (OpResult result = value.dyn_cast<OpResult>())
156 return setProvedLive(result.getOwner());
157 setProvedLive(value.cast<BlockArgument>());
158 }
setProvedLive(BlockArgument arg)159 void setProvedLive(BlockArgument arg) {
160 changed |= liveValues.insert(arg).second;
161 }
162
163 /// Operation methods.
wasProvenLive(Operation * op)164 bool wasProvenLive(Operation *op) { return liveOps.count(op); }
setProvedLive(Operation * op)165 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
166
167 /// Methods for tracking if we have reached a fixed-point.
resetChanged()168 void resetChanged() { changed = false; }
hasChanged()169 bool hasChanged() { return changed; }
170
171 private:
172 bool changed = false;
173 DenseSet<Value> liveValues;
174 DenseSet<Operation *> liveOps;
175 };
176 } // namespace
177
isUseSpeciallyKnownDead(OpOperand & use,LiveMap & liveMap)178 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
179 Operation *owner = use.getOwner();
180 unsigned operandIndex = use.getOperandNumber();
181 // This pass generally treats all uses of an op as live if the op itself is
182 // considered live. However, for successor operands to terminators we need a
183 // finer-grained notion where we deduce liveness for operands individually.
184 // The reason for this is easiest to think about in terms of a classical phi
185 // node based SSA IR, where each successor operand is really an operand to a
186 // *separate* phi node, rather than all operands to the branch itself as with
187 // the block argument representation that MLIR uses.
188 //
189 // And similarly, because each successor operand is really an operand to a phi
190 // node, rather than to the terminator op itself, a terminator op can't e.g.
191 // "print" the value of a successor operand.
192 if (owner->hasTrait<OpTrait::IsTerminator>()) {
193 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
194 if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
195 return !liveMap.wasProvenLive(*arg);
196 return false;
197 }
198 return false;
199 }
200
processValue(Value value,LiveMap & liveMap)201 static void processValue(Value value, LiveMap &liveMap) {
202 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
203 if (isUseSpeciallyKnownDead(use, liveMap))
204 return false;
205 return liveMap.wasProvenLive(use.getOwner());
206 });
207 if (provedLive)
208 liveMap.setProvedLive(value);
209 }
210
211 static void propagateLiveness(Region ®ion, LiveMap &liveMap);
212
propagateTerminatorLiveness(Operation * op,LiveMap & liveMap)213 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
214 // Terminators are always live.
215 liveMap.setProvedLive(op);
216
217 // Check to see if we can reason about the successor operands and mutate them.
218 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
219 if (!branchInterface) {
220 for (Block *successor : op->getSuccessors())
221 for (BlockArgument arg : successor->getArguments())
222 liveMap.setProvedLive(arg);
223 return;
224 }
225
226 // If we can't reason about the operand to a successor, conservatively mark
227 // it as live.
228 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
229 SuccessorOperands successorOperands =
230 branchInterface.getSuccessorOperands(i);
231 for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
232 opI != opE; ++opI)
233 liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
234 }
235 }
236
propagateLiveness(Operation * op,LiveMap & liveMap)237 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
238 // Recurse on any regions the op has.
239 for (Region ®ion : op->getRegions())
240 propagateLiveness(region, liveMap);
241
242 // Process terminator operations.
243 if (op->hasTrait<OpTrait::IsTerminator>())
244 return propagateTerminatorLiveness(op, liveMap);
245
246 // Don't reprocess live operations.
247 if (liveMap.wasProvenLive(op))
248 return;
249
250 // Process the op itself.
251 if (!wouldOpBeTriviallyDead(op))
252 return liveMap.setProvedLive(op);
253
254 // If the op isn't intrinsically alive, check it's results.
255 for (Value value : op->getResults())
256 processValue(value, liveMap);
257 }
258
propagateLiveness(Region & region,LiveMap & liveMap)259 static void propagateLiveness(Region ®ion, LiveMap &liveMap) {
260 if (region.empty())
261 return;
262
263 for (Block *block : llvm::post_order(®ion.front())) {
264 // We process block arguments after the ops in the block, to promote
265 // faster convergence to a fixed point (we try to visit uses before defs).
266 for (Operation &op : llvm::reverse(block->getOperations()))
267 propagateLiveness(&op, liveMap);
268
269 // We currently do not remove entry block arguments, so there is no need to
270 // track their liveness.
271 // TODO: We could track these and enable removing dead operands/arguments
272 // from region control flow operations.
273 if (block->isEntryBlock())
274 continue;
275
276 for (Value value : block->getArguments()) {
277 if (!liveMap.wasProvenLive(value))
278 processValue(value, liveMap);
279 }
280 }
281 }
282
eraseTerminatorSuccessorOperands(Operation * terminator,LiveMap & liveMap)283 static void eraseTerminatorSuccessorOperands(Operation *terminator,
284 LiveMap &liveMap) {
285 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
286 if (!branchOp)
287 return;
288
289 for (unsigned succI = 0, succE = terminator->getNumSuccessors();
290 succI < succE; succI++) {
291 // Iterating successors in reverse is not strictly needed, since we
292 // aren't erasing any successors. But it is slightly more efficient
293 // since it will promote later operands of the terminator being erased
294 // first, reducing the quadratic-ness.
295 unsigned succ = succE - succI - 1;
296 SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
297 Block *successor = terminator->getSuccessor(succ);
298
299 for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
300 // Iterating args in reverse is needed for correctness, to avoid
301 // shifting later args when earlier args are erased.
302 unsigned arg = argE - argI - 1;
303 if (!liveMap.wasProvenLive(successor->getArgument(arg)))
304 succOperands.erase(arg);
305 }
306 }
307 }
308
deleteDeadness(RewriterBase & rewriter,MutableArrayRef<Region> regions,LiveMap & liveMap)309 static LogicalResult deleteDeadness(RewriterBase &rewriter,
310 MutableArrayRef<Region> regions,
311 LiveMap &liveMap) {
312 bool erasedAnything = false;
313 for (Region ®ion : regions) {
314 if (region.empty())
315 continue;
316 bool hasSingleBlock = llvm::hasSingleElement(region);
317
318 // Delete every operation that is not live. Graph regions may have cycles
319 // in the use-def graph, so we must explicitly dropAllUses() from each
320 // operation as we erase it. Visiting the operations in post-order
321 // guarantees that in SSA CFG regions value uses are removed before defs,
322 // which makes dropAllUses() a no-op.
323 for (Block *block : llvm::post_order(®ion.front())) {
324 if (!hasSingleBlock)
325 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
326 for (Operation &childOp :
327 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
328 if (!liveMap.wasProvenLive(&childOp)) {
329 erasedAnything = true;
330 childOp.dropAllUses();
331 rewriter.eraseOp(&childOp);
332 } else {
333 erasedAnything |= succeeded(
334 deleteDeadness(rewriter, childOp.getRegions(), liveMap));
335 }
336 }
337 }
338 // Delete block arguments.
339 // The entry block has an unknown contract with their enclosing block, so
340 // skip it.
341 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
342 block.eraseArguments(
343 [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
344 }
345 }
346 return success(erasedAnything);
347 }
348
349 // This function performs a simple dead code elimination algorithm over the
350 // given regions.
351 //
352 // The overall goal is to prove that Values are dead, which allows deleting ops
353 // and block arguments.
354 //
355 // This uses an optimistic algorithm that assumes everything is dead until
356 // proved otherwise, allowing it to delete recursively dead cycles.
357 //
358 // This is a simple fixed-point dataflow analysis algorithm on a lattice
359 // {Dead,Alive}. Because liveness flows backward, we generally try to
360 // iterate everything backward to speed up convergence to the fixed-point. This
361 // allows for being able to delete recursively dead cycles of the use-def graph,
362 // including block arguments.
363 //
364 // This function returns success if any operations or arguments were deleted,
365 // failure otherwise.
runRegionDCE(RewriterBase & rewriter,MutableArrayRef<Region> regions)366 LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
367 MutableArrayRef<Region> regions) {
368 LiveMap liveMap;
369 do {
370 liveMap.resetChanged();
371
372 for (Region ®ion : regions)
373 propagateLiveness(region, liveMap);
374 } while (liveMap.hasChanged());
375
376 return deleteDeadness(rewriter, regions, liveMap);
377 }
378
379 //===----------------------------------------------------------------------===//
380 // Block Merging
381 //===----------------------------------------------------------------------===//
382
383 //===----------------------------------------------------------------------===//
384 // BlockEquivalenceData
385
386 namespace {
387 /// This class contains the information for comparing the equivalencies of two
388 /// blocks. Blocks are considered equivalent if they contain the same operations
389 /// in the same order. The only allowed divergence is for operands that come
390 /// from sources outside of the parent block, i.e. the uses of values produced
391 /// within the block must be equivalent.
392 /// e.g.,
393 /// Equivalent:
394 /// ^bb1(%arg0: i32)
395 /// return %arg0, %foo : i32, i32
396 /// ^bb2(%arg1: i32)
397 /// return %arg1, %bar : i32, i32
398 /// Not Equivalent:
399 /// ^bb1(%arg0: i32)
400 /// return %foo, %arg0 : i32, i32
401 /// ^bb2(%arg1: i32)
402 /// return %arg1, %bar : i32, i32
403 struct BlockEquivalenceData {
404 BlockEquivalenceData(Block *block);
405
406 /// Return the order index for the given value that is within the block of
407 /// this data.
408 unsigned getOrderOf(Value value) const;
409
410 /// The block this data refers to.
411 Block *block;
412 /// A hash value for this block.
413 llvm::hash_code hash;
414 /// A map of result producing operations to their relative orders within this
415 /// block. The order of an operation is the number of defined values that are
416 /// produced within the block before this operation.
417 DenseMap<Operation *, unsigned> opOrderIndex;
418 };
419 } // namespace
420
BlockEquivalenceData(Block * block)421 BlockEquivalenceData::BlockEquivalenceData(Block *block)
422 : block(block), hash(0) {
423 unsigned orderIt = block->getNumArguments();
424 for (Operation &op : *block) {
425 if (unsigned numResults = op.getNumResults()) {
426 opOrderIndex.try_emplace(&op, orderIt);
427 orderIt += numResults;
428 }
429 auto opHash = OperationEquivalence::computeHash(
430 &op, OperationEquivalence::ignoreHashValue,
431 OperationEquivalence::ignoreHashValue,
432 OperationEquivalence::IgnoreLocations);
433 hash = llvm::hash_combine(hash, opHash);
434 }
435 }
436
getOrderOf(Value value) const437 unsigned BlockEquivalenceData::getOrderOf(Value value) const {
438 assert(value.getParentBlock() == block && "expected value of this block");
439
440 // Arguments use the argument number as the order index.
441 if (BlockArgument arg = value.dyn_cast<BlockArgument>())
442 return arg.getArgNumber();
443
444 // Otherwise, the result order is offset from the parent op's order.
445 OpResult result = value.cast<OpResult>();
446 auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
447 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
448 return opOrderIt->second + result.getResultNumber();
449 }
450
451 //===----------------------------------------------------------------------===//
452 // BlockMergeCluster
453
454 namespace {
455 /// This class represents a cluster of blocks to be merged together.
456 class BlockMergeCluster {
457 public:
BlockMergeCluster(BlockEquivalenceData && leaderData)458 BlockMergeCluster(BlockEquivalenceData &&leaderData)
459 : leaderData(std::move(leaderData)) {}
460
461 /// Attempt to add the given block to this cluster. Returns success if the
462 /// block was merged, failure otherwise.
463 LogicalResult addToCluster(BlockEquivalenceData &blockData);
464
465 /// Try to merge all of the blocks within this cluster into the leader block.
466 LogicalResult merge(RewriterBase &rewriter);
467
468 private:
469 /// The equivalence data for the leader of the cluster.
470 BlockEquivalenceData leaderData;
471
472 /// The set of blocks that can be merged into the leader.
473 llvm::SmallSetVector<Block *, 1> blocksToMerge;
474
475 /// A set of operand+index pairs that correspond to operands that need to be
476 /// replaced by arguments when the cluster gets merged.
477 std::set<std::pair<int, int>> operandsToMerge;
478 };
479 } // namespace
480
addToCluster(BlockEquivalenceData & blockData)481 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
482 if (leaderData.hash != blockData.hash)
483 return failure();
484 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
485 if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
486 return failure();
487
488 // A set of operands that mismatch between the leader and the new block.
489 SmallVector<std::pair<int, int>, 8> mismatchedOperands;
490 auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
491 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
492 for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
493 // Check that the operations are equivalent.
494 if (!OperationEquivalence::isEquivalentTo(
495 &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
496 OperationEquivalence::ignoreValueEquivalence,
497 OperationEquivalence::Flags::IgnoreLocations))
498 return failure();
499
500 // Compare the operands of the two operations. If the operand is within
501 // the block, it must refer to the same operation.
502 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
503 for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
504 Value lhsOperand = lhsOperands[operand];
505 Value rhsOperand = rhsOperands[operand];
506 if (lhsOperand == rhsOperand)
507 continue;
508 // Check that the types of the operands match.
509 if (lhsOperand.getType() != rhsOperand.getType())
510 return failure();
511
512 // Check that these uses are both external, or both internal.
513 bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
514 bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
515 if (lhsIsInBlock != rhsIsInBlock)
516 return failure();
517 // Let the operands differ if they are defined in a different block. These
518 // will become new arguments if the blocks get merged.
519 if (!lhsIsInBlock) {
520
521 // Check whether the operands aren't the result of an immediate
522 // predecessors terminator. In that case we are not able to use it as a
523 // successor operand when branching to the merged block as it does not
524 // dominate its producing operation.
525 auto isValidSuccessorArg = [](Block *block, Value operand) {
526 if (operand.getDefiningOp() !=
527 operand.getParentBlock()->getTerminator())
528 return true;
529 return !llvm::is_contained(block->getPredecessors(),
530 operand.getParentBlock());
531 };
532
533 if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
534 !isValidSuccessorArg(mergeBlock, rhsOperand))
535 return failure();
536
537 mismatchedOperands.emplace_back(opI, operand);
538 continue;
539 }
540
541 // Otherwise, these operands must have the same logical order within the
542 // parent block.
543 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
544 return failure();
545 }
546
547 // If the lhs or rhs has external uses, the blocks cannot be merged as the
548 // merged version of this operation will not be either the lhs or rhs
549 // alone (thus semantically incorrect), but some mix dependending on which
550 // block preceeded this.
551 // TODO allow merging of operations when one block does not dominate the
552 // other
553 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
554 lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
555 return failure();
556 }
557 }
558 // Make sure that the block sizes are equivalent.
559 if (lhsIt != lhsE || rhsIt != rhsE)
560 return failure();
561
562 // If we get here, the blocks are equivalent and can be merged.
563 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
564 blocksToMerge.insert(blockData.block);
565 return success();
566 }
567
568 /// Returns true if the predecessor terminators of the given block can not have
569 /// their operands updated.
ableToUpdatePredOperands(Block * block)570 static bool ableToUpdatePredOperands(Block *block) {
571 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
572 if (!isa<BranchOpInterface>((*it)->getTerminator()))
573 return false;
574 }
575 return true;
576 }
577
merge(RewriterBase & rewriter)578 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
579 // Don't consider clusters that don't have blocks to merge.
580 if (blocksToMerge.empty())
581 return failure();
582
583 Block *leaderBlock = leaderData.block;
584 if (!operandsToMerge.empty()) {
585 // If the cluster has operands to merge, verify that the predecessor
586 // terminators of each of the blocks can have their successor operands
587 // updated.
588 // TODO: We could try and sub-partition this cluster if only some blocks
589 // cause the mismatch.
590 if (!ableToUpdatePredOperands(leaderBlock) ||
591 !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
592 return failure();
593
594 // Collect the iterators for each of the blocks to merge. We will walk all
595 // of the iterators at once to avoid operand index invalidation.
596 SmallVector<Block::iterator, 2> blockIterators;
597 blockIterators.reserve(blocksToMerge.size() + 1);
598 blockIterators.push_back(leaderBlock->begin());
599 for (Block *mergeBlock : blocksToMerge)
600 blockIterators.push_back(mergeBlock->begin());
601
602 // Update each of the predecessor terminators with the new arguments.
603 SmallVector<SmallVector<Value, 8>, 2> newArguments(
604 1 + blocksToMerge.size(),
605 SmallVector<Value, 8>(operandsToMerge.size()));
606 unsigned curOpIndex = 0;
607 for (const auto &it : llvm::enumerate(operandsToMerge)) {
608 unsigned nextOpOffset = it.value().first - curOpIndex;
609 curOpIndex = it.value().first;
610
611 // Process the operand for each of the block iterators.
612 for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
613 Block::iterator &blockIter = blockIterators[i];
614 std::advance(blockIter, nextOpOffset);
615 auto &operand = blockIter->getOpOperand(it.value().second);
616 newArguments[i][it.index()] = operand.get();
617
618 // Update the operand and insert an argument if this is the leader.
619 if (i == 0) {
620 Value operandVal = operand.get();
621 operand.set(leaderBlock->addArgument(operandVal.getType(),
622 operandVal.getLoc()));
623 }
624 }
625 }
626 // Update the predecessors for each of the blocks.
627 auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
628 for (auto predIt = block->pred_begin(), predE = block->pred_end();
629 predIt != predE; ++predIt) {
630 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
631 unsigned succIndex = predIt.getSuccessorIndex();
632 branch.getSuccessorOperands(succIndex).append(
633 newArguments[clusterIndex]);
634 }
635 };
636 updatePredecessors(leaderBlock, /*clusterIndex=*/0);
637 for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
638 updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
639 }
640
641 // Replace all uses of the merged blocks with the leader and erase them.
642 for (Block *block : blocksToMerge) {
643 block->replaceAllUsesWith(leaderBlock);
644 rewriter.eraseBlock(block);
645 }
646 return success();
647 }
648
649 /// Identify identical blocks within the given region and merge them, inserting
650 /// new block arguments as necessary. Returns success if any blocks were merged,
651 /// failure otherwise.
mergeIdenticalBlocks(RewriterBase & rewriter,Region & region)652 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
653 Region ®ion) {
654 if (region.empty() || llvm::hasSingleElement(region))
655 return failure();
656
657 // Identify sets of blocks, other than the entry block, that branch to the
658 // same successors. We will use these groups to create clusters of equivalent
659 // blocks.
660 DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
661 for (Block &block : llvm::drop_begin(region, 1))
662 matchingSuccessors[block.getSuccessors()].push_back(&block);
663
664 bool mergedAnyBlocks = false;
665 for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
666 if (blocks.size() == 1)
667 continue;
668
669 SmallVector<BlockMergeCluster, 1> clusters;
670 for (Block *block : blocks) {
671 BlockEquivalenceData data(block);
672
673 // Don't allow merging if this block has any regions.
674 // TODO: Add support for regions if necessary.
675 bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
676 return llvm::any_of(op.getRegions(),
677 [](Region ®ion) { return !region.empty(); });
678 });
679 if (hasNonEmptyRegion)
680 continue;
681
682 // Try to add this block to an existing cluster.
683 bool addedToCluster = false;
684 for (auto &cluster : clusters)
685 if ((addedToCluster = succeeded(cluster.addToCluster(data))))
686 break;
687 if (!addedToCluster)
688 clusters.emplace_back(std::move(data));
689 }
690 for (auto &cluster : clusters)
691 mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
692 }
693
694 return success(mergedAnyBlocks);
695 }
696
697 /// Identify identical blocks within the given regions and merge them, inserting
698 /// new block arguments as necessary.
mergeIdenticalBlocks(RewriterBase & rewriter,MutableArrayRef<Region> regions)699 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
700 MutableArrayRef<Region> regions) {
701 llvm::SmallSetVector<Region *, 1> worklist;
702 for (auto ®ion : regions)
703 worklist.insert(®ion);
704 bool anyChanged = false;
705 while (!worklist.empty()) {
706 Region *region = worklist.pop_back_val();
707 if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
708 worklist.insert(region);
709 anyChanged = true;
710 }
711
712 // Add any nested regions to the worklist.
713 for (Block &block : *region)
714 for (auto &op : block)
715 for (auto &nestedRegion : op.getRegions())
716 worklist.insert(&nestedRegion);
717 }
718
719 return success(anyChanged);
720 }
721
722 //===----------------------------------------------------------------------===//
723 // Region Simplification
724 //===----------------------------------------------------------------------===//
725
726 /// Run a set of structural simplifications over the given regions. This
727 /// includes transformations like unreachable block elimination, dead argument
728 /// elimination, as well as some other DCE. This function returns success if any
729 /// of the regions were simplified, failure otherwise.
simplifyRegions(RewriterBase & rewriter,MutableArrayRef<Region> regions)730 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
731 MutableArrayRef<Region> regions) {
732 bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
733 bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
734 bool mergedIdenticalBlocks =
735 succeeded(mergeIdenticalBlocks(rewriter, regions));
736 return success(eliminatedBlocks || eliminatedOpsOrArgs ||
737 mergedIdenticalBlocks);
738 }
739