1ace01605SRiver Riddle //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2ace01605SRiver Riddle //
3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ace01605SRiver Riddle //
7ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8ace01605SRiver Riddle 
9ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10ace01605SRiver Riddle 
11ace01605SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12ace01605SRiver Riddle #include "mlir/Dialect/CommonFolders.h"
13ace01605SRiver Riddle #include "mlir/IR/AffineExpr.h"
14ace01605SRiver Riddle #include "mlir/IR/AffineMap.h"
15ace01605SRiver Riddle #include "mlir/IR/BlockAndValueMapping.h"
16ace01605SRiver Riddle #include "mlir/IR/Builders.h"
17ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h"
18ace01605SRiver Riddle #include "mlir/IR/BuiltinTypes.h"
19ace01605SRiver Riddle #include "mlir/IR/Matchers.h"
20ace01605SRiver Riddle #include "mlir/IR/OpImplementation.h"
21ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h"
22ace01605SRiver Riddle #include "mlir/IR/TypeUtilities.h"
23ace01605SRiver Riddle #include "mlir/IR/Value.h"
24ace01605SRiver Riddle #include "mlir/Support/MathExtras.h"
25ace01605SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
26ace01605SRiver Riddle #include "llvm/ADT/APFloat.h"
27ace01605SRiver Riddle #include "llvm/ADT/STLExtras.h"
28ace01605SRiver Riddle #include "llvm/ADT/StringSwitch.h"
29ace01605SRiver Riddle #include "llvm/Support/FormatVariadic.h"
30ace01605SRiver Riddle #include "llvm/Support/raw_ostream.h"
31ace01605SRiver Riddle #include <numeric>
32ace01605SRiver Riddle 
33ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34ace01605SRiver Riddle 
35ace01605SRiver Riddle using namespace mlir;
36ace01605SRiver Riddle using namespace mlir::cf;
37ace01605SRiver Riddle 
38ace01605SRiver Riddle //===----------------------------------------------------------------------===//
39ace01605SRiver Riddle // ControlFlowDialect Interfaces
40ace01605SRiver Riddle //===----------------------------------------------------------------------===//
41ace01605SRiver Riddle namespace {
42ace01605SRiver Riddle /// This class defines the interface for handling inlining with control flow
43ace01605SRiver Riddle /// operations.
44ace01605SRiver Riddle struct ControlFlowInlinerInterface : public DialectInlinerInterface {
45ace01605SRiver Riddle   using DialectInlinerInterface::DialectInlinerInterface;
46ace01605SRiver Riddle   ~ControlFlowInlinerInterface() override = default;
47ace01605SRiver Riddle 
48ace01605SRiver Riddle   /// All control flow operations can be inlined.
isLegalToInline__anonbb8b539b0111::ControlFlowInlinerInterface49ace01605SRiver Riddle   bool isLegalToInline(Operation *call, Operation *callable,
50ace01605SRiver Riddle                        bool wouldBeCloned) const final {
51ace01605SRiver Riddle     return true;
52ace01605SRiver Riddle   }
isLegalToInline__anonbb8b539b0111::ControlFlowInlinerInterface53ace01605SRiver Riddle   bool isLegalToInline(Operation *, Region *, bool,
54ace01605SRiver Riddle                        BlockAndValueMapping &) const final {
55ace01605SRiver Riddle     return true;
56ace01605SRiver Riddle   }
57ace01605SRiver Riddle 
58ace01605SRiver Riddle   /// ControlFlow terminator operations don't really need any special handing.
handleTerminator__anonbb8b539b0111::ControlFlowInlinerInterface59ace01605SRiver Riddle   void handleTerminator(Operation *op, Block *newDest) const final {}
60ace01605SRiver Riddle };
61ace01605SRiver Riddle } // namespace
62ace01605SRiver Riddle 
63ace01605SRiver Riddle //===----------------------------------------------------------------------===//
64ace01605SRiver Riddle // ControlFlowDialect
65ace01605SRiver Riddle //===----------------------------------------------------------------------===//
66ace01605SRiver Riddle 
initialize()67ace01605SRiver Riddle void ControlFlowDialect::initialize() {
68ace01605SRiver Riddle   addOperations<
69ace01605SRiver Riddle #define GET_OP_LIST
70ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
71ace01605SRiver Riddle       >();
72ace01605SRiver Riddle   addInterfaces<ControlFlowInlinerInterface>();
73ace01605SRiver Riddle }
74ace01605SRiver Riddle 
75ace01605SRiver Riddle //===----------------------------------------------------------------------===//
76ace01605SRiver Riddle // AssertOp
77ace01605SRiver Riddle //===----------------------------------------------------------------------===//
78ace01605SRiver Riddle 
canonicalize(AssertOp op,PatternRewriter & rewriter)79ace01605SRiver Riddle LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
80ace01605SRiver Riddle   // Erase assertion if argument is constant true.
81ace01605SRiver Riddle   if (matchPattern(op.getArg(), m_One())) {
82ace01605SRiver Riddle     rewriter.eraseOp(op);
83ace01605SRiver Riddle     return success();
84ace01605SRiver Riddle   }
85ace01605SRiver Riddle   return failure();
86ace01605SRiver Riddle }
87ace01605SRiver Riddle 
88ace01605SRiver Riddle //===----------------------------------------------------------------------===//
89ace01605SRiver Riddle // BranchOp
90ace01605SRiver Riddle //===----------------------------------------------------------------------===//
91ace01605SRiver Riddle 
92ace01605SRiver Riddle /// Given a successor, try to collapse it to a new destination if it only
93ace01605SRiver Riddle /// contains a passthrough unconditional branch. If the successor is
94ace01605SRiver Riddle /// collapsable, `successor` and `successorOperands` are updated to reference
95ace01605SRiver Riddle /// the new destination and values. `argStorage` is used as storage if operands
96ace01605SRiver Riddle /// to the collapsed successor need to be remapped. It must outlive uses of
97ace01605SRiver Riddle /// successorOperands.
collapseBranch(Block * & successor,ValueRange & successorOperands,SmallVectorImpl<Value> & argStorage)98ace01605SRiver Riddle static LogicalResult collapseBranch(Block *&successor,
99ace01605SRiver Riddle                                     ValueRange &successorOperands,
100ace01605SRiver Riddle                                     SmallVectorImpl<Value> &argStorage) {
101ace01605SRiver Riddle   // Check that the successor only contains a unconditional branch.
102ace01605SRiver Riddle   if (std::next(successor->begin()) != successor->end())
103ace01605SRiver Riddle     return failure();
104ace01605SRiver Riddle   // Check that the terminator is an unconditional branch.
105ace01605SRiver Riddle   BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
106ace01605SRiver Riddle   if (!successorBranch)
107ace01605SRiver Riddle     return failure();
108ace01605SRiver Riddle   // Check that the arguments are only used within the terminator.
109ace01605SRiver Riddle   for (BlockArgument arg : successor->getArguments()) {
110ace01605SRiver Riddle     for (Operation *user : arg.getUsers())
111ace01605SRiver Riddle       if (user != successorBranch)
112ace01605SRiver Riddle         return failure();
113ace01605SRiver Riddle   }
114ace01605SRiver Riddle   // Don't try to collapse branches to infinite loops.
115ace01605SRiver Riddle   Block *successorDest = successorBranch.getDest();
116ace01605SRiver Riddle   if (successorDest == successor)
117ace01605SRiver Riddle     return failure();
118ace01605SRiver Riddle 
119ace01605SRiver Riddle   // Update the operands to the successor. If the branch parent has no
120ace01605SRiver Riddle   // arguments, we can use the branch operands directly.
121ace01605SRiver Riddle   OperandRange operands = successorBranch.getOperands();
122ace01605SRiver Riddle   if (successor->args_empty()) {
123ace01605SRiver Riddle     successor = successorDest;
124ace01605SRiver Riddle     successorOperands = operands;
125ace01605SRiver Riddle     return success();
126ace01605SRiver Riddle   }
127ace01605SRiver Riddle 
128ace01605SRiver Riddle   // Otherwise, we need to remap any argument operands.
129ace01605SRiver Riddle   for (Value operand : operands) {
130ace01605SRiver Riddle     BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
131ace01605SRiver Riddle     if (argOperand && argOperand.getOwner() == successor)
132ace01605SRiver Riddle       argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
133ace01605SRiver Riddle     else
134ace01605SRiver Riddle       argStorage.push_back(operand);
135ace01605SRiver Riddle   }
136ace01605SRiver Riddle   successor = successorDest;
137ace01605SRiver Riddle   successorOperands = argStorage;
138ace01605SRiver Riddle   return success();
139ace01605SRiver Riddle }
140ace01605SRiver Riddle 
141ace01605SRiver Riddle /// Simplify a branch to a block that has a single predecessor. This effectively
142ace01605SRiver Riddle /// merges the two blocks.
143ace01605SRiver Riddle static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op,PatternRewriter & rewriter)144ace01605SRiver Riddle simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
145ace01605SRiver Riddle   // Check that the successor block has a single predecessor.
146ace01605SRiver Riddle   Block *succ = op.getDest();
147ace01605SRiver Riddle   Block *opParent = op->getBlock();
148ace01605SRiver Riddle   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
149ace01605SRiver Riddle     return failure();
150ace01605SRiver Riddle 
151ace01605SRiver Riddle   // Merge the successor into the current block and erase the branch.
152ace01605SRiver Riddle   rewriter.mergeBlocks(succ, opParent, op.getOperands());
153ace01605SRiver Riddle   rewriter.eraseOp(op);
154ace01605SRiver Riddle   return success();
155ace01605SRiver Riddle }
156ace01605SRiver Riddle 
157ace01605SRiver Riddle ///   br ^bb1
158ace01605SRiver Riddle /// ^bb1
159ace01605SRiver Riddle ///   br ^bbN(...)
160ace01605SRiver Riddle ///
161ace01605SRiver Riddle ///  -> br ^bbN(...)
162ace01605SRiver Riddle ///
simplifyPassThroughBr(BranchOp op,PatternRewriter & rewriter)163ace01605SRiver Riddle static LogicalResult simplifyPassThroughBr(BranchOp op,
164ace01605SRiver Riddle                                            PatternRewriter &rewriter) {
165ace01605SRiver Riddle   Block *dest = op.getDest();
166ace01605SRiver Riddle   ValueRange destOperands = op.getOperands();
167ace01605SRiver Riddle   SmallVector<Value, 4> destOperandStorage;
168ace01605SRiver Riddle 
169ace01605SRiver Riddle   // Try to collapse the successor if it points somewhere other than this
170ace01605SRiver Riddle   // block.
171ace01605SRiver Riddle   if (dest == op->getBlock() ||
172ace01605SRiver Riddle       failed(collapseBranch(dest, destOperands, destOperandStorage)))
173ace01605SRiver Riddle     return failure();
174ace01605SRiver Riddle 
175ace01605SRiver Riddle   // Create a new branch with the collapsed successor.
176ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
177ace01605SRiver Riddle   return success();
178ace01605SRiver Riddle }
179ace01605SRiver Riddle 
canonicalize(BranchOp op,PatternRewriter & rewriter)180ace01605SRiver Riddle LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
181ace01605SRiver Riddle   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
182ace01605SRiver Riddle                  succeeded(simplifyPassThroughBr(op, rewriter)));
183ace01605SRiver Riddle }
184ace01605SRiver Riddle 
setDest(Block * block)185ace01605SRiver Riddle void BranchOp::setDest(Block *block) { return setSuccessor(block); }
186ace01605SRiver Riddle 
eraseOperand(unsigned index)187ace01605SRiver Riddle void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
188ace01605SRiver Riddle 
getSuccessorOperands(unsigned index)1890c789db5SMarkus Böck SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
190ace01605SRiver Riddle   assert(index == 0 && "invalid successor index");
1910c789db5SMarkus Böck   return SuccessorOperands(getDestOperandsMutable());
192ace01605SRiver Riddle }
193ace01605SRiver Riddle 
getSuccessorForOperands(ArrayRef<Attribute>)194ace01605SRiver Riddle Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
195ace01605SRiver Riddle   return getDest();
196ace01605SRiver Riddle }
197ace01605SRiver Riddle 
198ace01605SRiver Riddle //===----------------------------------------------------------------------===//
199ace01605SRiver Riddle // CondBranchOp
200ace01605SRiver Riddle //===----------------------------------------------------------------------===//
201ace01605SRiver Riddle 
202ace01605SRiver Riddle namespace {
203ace01605SRiver Riddle /// cf.cond_br true, ^bb1, ^bb2
204ace01605SRiver Riddle ///  -> br ^bb1
205ace01605SRiver Riddle /// cf.cond_br false, ^bb1, ^bb2
206ace01605SRiver Riddle ///  -> br ^bb2
207ace01605SRiver Riddle ///
208ace01605SRiver Riddle struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
209ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
210ace01605SRiver Riddle 
matchAndRewrite__anonbb8b539b0211::SimplifyConstCondBranchPred211ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
212ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
213ace01605SRiver Riddle     if (matchPattern(condbr.getCondition(), m_NonZero())) {
214ace01605SRiver Riddle       // True branch taken.
215ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
216ace01605SRiver Riddle                                             condbr.getTrueOperands());
217ace01605SRiver Riddle       return success();
218ace01605SRiver Riddle     }
219ace01605SRiver Riddle     if (matchPattern(condbr.getCondition(), m_Zero())) {
220ace01605SRiver Riddle       // False branch taken.
221ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
222ace01605SRiver Riddle                                             condbr.getFalseOperands());
223ace01605SRiver Riddle       return success();
224ace01605SRiver Riddle     }
225ace01605SRiver Riddle     return failure();
226ace01605SRiver Riddle   }
227ace01605SRiver Riddle };
228ace01605SRiver Riddle 
229ace01605SRiver Riddle ///   cf.cond_br %cond, ^bb1, ^bb2
230ace01605SRiver Riddle /// ^bb1
231ace01605SRiver Riddle ///   br ^bbN(...)
232ace01605SRiver Riddle /// ^bb2
233ace01605SRiver Riddle ///   br ^bbK(...)
234ace01605SRiver Riddle ///
235ace01605SRiver Riddle ///  -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
236ace01605SRiver Riddle ///
237ace01605SRiver Riddle struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
238ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
239ace01605SRiver Riddle 
matchAndRewrite__anonbb8b539b0211::SimplifyPassThroughCondBranch240ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
241ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
242ace01605SRiver Riddle     Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
243ace01605SRiver Riddle     ValueRange trueDestOperands = condbr.getTrueOperands();
244ace01605SRiver Riddle     ValueRange falseDestOperands = condbr.getFalseOperands();
245ace01605SRiver Riddle     SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
246ace01605SRiver Riddle 
247ace01605SRiver Riddle     // Try to collapse one of the current successors.
248ace01605SRiver Riddle     LogicalResult collapsedTrue =
249ace01605SRiver Riddle         collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
250ace01605SRiver Riddle     LogicalResult collapsedFalse =
251ace01605SRiver Riddle         collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
252ace01605SRiver Riddle     if (failed(collapsedTrue) && failed(collapsedFalse))
253ace01605SRiver Riddle       return failure();
254ace01605SRiver Riddle 
255ace01605SRiver Riddle     // Create a new branch with the collapsed successors.
256ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
257ace01605SRiver Riddle                                               trueDest, trueDestOperands,
258ace01605SRiver Riddle                                               falseDest, falseDestOperands);
259ace01605SRiver Riddle     return success();
260ace01605SRiver Riddle   }
261ace01605SRiver Riddle };
262ace01605SRiver Riddle 
263ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
264ace01605SRiver Riddle ///  -> br ^bb1(A, ..., N)
265ace01605SRiver Riddle ///
266ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
267ace01605SRiver Riddle ///  -> %select = arith.select %cond, A, B
268ace01605SRiver Riddle ///     br ^bb1(%select)
269ace01605SRiver Riddle ///
270ace01605SRiver Riddle struct SimplifyCondBranchIdenticalSuccessors
271ace01605SRiver Riddle     : public OpRewritePattern<CondBranchOp> {
272ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
273ace01605SRiver Riddle 
matchAndRewrite__anonbb8b539b0211::SimplifyCondBranchIdenticalSuccessors274ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
275ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
276ace01605SRiver Riddle     // Check that the true and false destinations are the same and have the same
277ace01605SRiver Riddle     // operands.
278ace01605SRiver Riddle     Block *trueDest = condbr.getTrueDest();
279ace01605SRiver Riddle     if (trueDest != condbr.getFalseDest())
280ace01605SRiver Riddle       return failure();
281ace01605SRiver Riddle 
282ace01605SRiver Riddle     // If all of the operands match, no selects need to be generated.
283ace01605SRiver Riddle     OperandRange trueOperands = condbr.getTrueOperands();
284ace01605SRiver Riddle     OperandRange falseOperands = condbr.getFalseOperands();
285ace01605SRiver Riddle     if (trueOperands == falseOperands) {
286ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
287ace01605SRiver Riddle       return success();
288ace01605SRiver Riddle     }
289ace01605SRiver Riddle 
290ace01605SRiver Riddle     // Otherwise, if the current block is the only predecessor insert selects
291ace01605SRiver Riddle     // for any mismatched branch operands.
292ace01605SRiver Riddle     if (trueDest->getUniquePredecessor() != condbr->getBlock())
293ace01605SRiver Riddle       return failure();
294ace01605SRiver Riddle 
295ace01605SRiver Riddle     // Generate a select for any operands that differ between the two.
296ace01605SRiver Riddle     SmallVector<Value, 8> mergedOperands;
297ace01605SRiver Riddle     mergedOperands.reserve(trueOperands.size());
298ace01605SRiver Riddle     Value condition = condbr.getCondition();
299ace01605SRiver Riddle     for (auto it : llvm::zip(trueOperands, falseOperands)) {
300ace01605SRiver Riddle       if (std::get<0>(it) == std::get<1>(it))
301ace01605SRiver Riddle         mergedOperands.push_back(std::get<0>(it));
302ace01605SRiver Riddle       else
303ace01605SRiver Riddle         mergedOperands.push_back(rewriter.create<arith::SelectOp>(
304ace01605SRiver Riddle             condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
305ace01605SRiver Riddle     }
306ace01605SRiver Riddle 
307ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
308ace01605SRiver Riddle     return success();
309ace01605SRiver Riddle   }
310ace01605SRiver Riddle };
311ace01605SRiver Riddle 
312ace01605SRiver Riddle ///   ...
313ace01605SRiver Riddle ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
314ace01605SRiver Riddle /// ...
315ace01605SRiver Riddle /// ^bb1: // has single predecessor
316ace01605SRiver Riddle ///   ...
317ace01605SRiver Riddle ///   cf.cond_br %cond, ^bb3(...), ^bb4(...)
318ace01605SRiver Riddle ///
319ace01605SRiver Riddle /// ->
320ace01605SRiver Riddle ///
321ace01605SRiver Riddle ///   ...
322ace01605SRiver Riddle ///   cf.cond_br %cond, ^bb1(...), ^bb2(...)
323ace01605SRiver Riddle /// ...
324ace01605SRiver Riddle /// ^bb1: // has single predecessor
325ace01605SRiver Riddle ///   ...
326ace01605SRiver Riddle ///   br ^bb3(...)
327ace01605SRiver Riddle ///
328ace01605SRiver Riddle struct SimplifyCondBranchFromCondBranchOnSameCondition
329ace01605SRiver Riddle     : public OpRewritePattern<CondBranchOp> {
330ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
331ace01605SRiver Riddle 
matchAndRewrite__anonbb8b539b0211::SimplifyCondBranchFromCondBranchOnSameCondition332ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
333ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
334ace01605SRiver Riddle     // Check that we have a single distinct predecessor.
335ace01605SRiver Riddle     Block *currentBlock = condbr->getBlock();
336ace01605SRiver Riddle     Block *predecessor = currentBlock->getSinglePredecessor();
337ace01605SRiver Riddle     if (!predecessor)
338ace01605SRiver Riddle       return failure();
339ace01605SRiver Riddle 
340ace01605SRiver Riddle     // Check that the predecessor terminates with a conditional branch to this
341ace01605SRiver Riddle     // block and that it branches on the same condition.
342ace01605SRiver Riddle     auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
343ace01605SRiver Riddle     if (!predBranch || condbr.getCondition() != predBranch.getCondition())
344ace01605SRiver Riddle       return failure();
345ace01605SRiver Riddle 
346ace01605SRiver Riddle     // Fold this branch to an unconditional branch.
347ace01605SRiver Riddle     if (currentBlock == predBranch.getTrueDest())
348ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
349ace01605SRiver Riddle                                             condbr.getTrueDestOperands());
350ace01605SRiver Riddle     else
351ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
352ace01605SRiver Riddle                                             condbr.getFalseDestOperands());
353ace01605SRiver Riddle     return success();
354ace01605SRiver Riddle   }
355ace01605SRiver Riddle };
356ace01605SRiver Riddle 
357ace01605SRiver Riddle ///   cf.cond_br %arg0, ^trueB, ^falseB
358ace01605SRiver Riddle ///
359ace01605SRiver Riddle /// ^trueB:
360ace01605SRiver Riddle ///   "test.consumer1"(%arg0) : (i1) -> ()
361ace01605SRiver Riddle ///    ...
362ace01605SRiver Riddle ///
363ace01605SRiver Riddle /// ^falseB:
364ace01605SRiver Riddle ///   "test.consumer2"(%arg0) : (i1) -> ()
365ace01605SRiver Riddle ///   ...
366ace01605SRiver Riddle ///
367ace01605SRiver Riddle /// ->
368ace01605SRiver Riddle ///
369ace01605SRiver Riddle ///   cf.cond_br %arg0, ^trueB, ^falseB
370ace01605SRiver Riddle /// ^trueB:
371ace01605SRiver Riddle ///   "test.consumer1"(%true) : (i1) -> ()
372ace01605SRiver Riddle ///   ...
373ace01605SRiver Riddle ///
374ace01605SRiver Riddle /// ^falseB:
375ace01605SRiver Riddle ///   "test.consumer2"(%false) : (i1) -> ()
376ace01605SRiver Riddle ///   ...
377ace01605SRiver Riddle struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
378ace01605SRiver Riddle   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
379ace01605SRiver Riddle 
matchAndRewrite__anonbb8b539b0211::CondBranchTruthPropagation380ace01605SRiver Riddle   LogicalResult matchAndRewrite(CondBranchOp condbr,
381ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override {
382ace01605SRiver Riddle     // Check that we have a single distinct predecessor.
383ace01605SRiver Riddle     bool replaced = false;
384ace01605SRiver Riddle     Type ty = rewriter.getI1Type();
385ace01605SRiver Riddle 
386ace01605SRiver Riddle     // These variables serve to prevent creating duplicate constants
387ace01605SRiver Riddle     // and hold constant true or false values.
388ace01605SRiver Riddle     Value constantTrue = nullptr;
389ace01605SRiver Riddle     Value constantFalse = nullptr;
390ace01605SRiver Riddle 
391ace01605SRiver Riddle     // TODO These checks can be expanded to encompas any use with only
392ace01605SRiver Riddle     // either the true of false edge as a predecessor. For now, we fall
393ace01605SRiver Riddle     // back to checking the single predecessor is given by the true/fasle
394ace01605SRiver Riddle     // destination, thereby ensuring that only that edge can reach the
395ace01605SRiver Riddle     // op.
396ace01605SRiver Riddle     if (condbr.getTrueDest()->getSinglePredecessor()) {
397ace01605SRiver Riddle       for (OpOperand &use :
398ace01605SRiver Riddle            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
399ace01605SRiver Riddle         if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
400ace01605SRiver Riddle           replaced = true;
401ace01605SRiver Riddle 
402ace01605SRiver Riddle           if (!constantTrue)
403ace01605SRiver Riddle             constantTrue = rewriter.create<arith::ConstantOp>(
404ace01605SRiver Riddle                 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
405ace01605SRiver Riddle 
406ace01605SRiver Riddle           rewriter.updateRootInPlace(use.getOwner(),
407ace01605SRiver Riddle                                      [&] { use.set(constantTrue); });
408ace01605SRiver Riddle         }
409ace01605SRiver Riddle       }
410ace01605SRiver Riddle     }
411ace01605SRiver Riddle     if (condbr.getFalseDest()->getSinglePredecessor()) {
412ace01605SRiver Riddle       for (OpOperand &use :
413ace01605SRiver Riddle            llvm::make_early_inc_range(condbr.getCondition().getUses())) {
414ace01605SRiver Riddle         if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
415ace01605SRiver Riddle           replaced = true;
416ace01605SRiver Riddle 
417ace01605SRiver Riddle           if (!constantFalse)
418ace01605SRiver Riddle             constantFalse = rewriter.create<arith::ConstantOp>(
419ace01605SRiver Riddle                 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
420ace01605SRiver Riddle 
421ace01605SRiver Riddle           rewriter.updateRootInPlace(use.getOwner(),
422ace01605SRiver Riddle                                      [&] { use.set(constantFalse); });
423ace01605SRiver Riddle         }
424ace01605SRiver Riddle       }
425ace01605SRiver Riddle     }
426ace01605SRiver Riddle     return success(replaced);
427ace01605SRiver Riddle   }
428ace01605SRiver Riddle };
429ace01605SRiver Riddle } // namespace
430ace01605SRiver Riddle 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)431ace01605SRiver Riddle void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
432ace01605SRiver Riddle                                                MLIRContext *context) {
433ace01605SRiver Riddle   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
434ace01605SRiver Riddle               SimplifyCondBranchIdenticalSuccessors,
435ace01605SRiver Riddle               SimplifyCondBranchFromCondBranchOnSameCondition,
436ace01605SRiver Riddle               CondBranchTruthPropagation>(context);
437ace01605SRiver Riddle }
438ace01605SRiver Riddle 
getSuccessorOperands(unsigned index)4390c789db5SMarkus Böck SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
440ace01605SRiver Riddle   assert(index < getNumSuccessors() && "invalid successor index");
4410c789db5SMarkus Böck   return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
4420c789db5SMarkus Böck                                               : getFalseDestOperandsMutable());
443ace01605SRiver Riddle }
444ace01605SRiver Riddle 
getSuccessorForOperands(ArrayRef<Attribute> operands)445ace01605SRiver Riddle Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
446ace01605SRiver Riddle   if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
447ace01605SRiver Riddle     return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
448ace01605SRiver Riddle   return nullptr;
449ace01605SRiver Riddle }
450ace01605SRiver Riddle 
451ace01605SRiver Riddle //===----------------------------------------------------------------------===//
452ace01605SRiver Riddle // SwitchOp
453ace01605SRiver Riddle //===----------------------------------------------------------------------===//
454ace01605SRiver Riddle 
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,DenseIntElementsAttr caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)455ace01605SRiver Riddle void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
456ace01605SRiver Riddle                      Block *defaultDestination, ValueRange defaultOperands,
457ace01605SRiver Riddle                      DenseIntElementsAttr caseValues,
458ace01605SRiver Riddle                      BlockRange caseDestinations,
459ace01605SRiver Riddle                      ArrayRef<ValueRange> caseOperands) {
460ace01605SRiver Riddle   build(builder, result, value, defaultOperands, caseOperands, caseValues,
461ace01605SRiver Riddle         defaultDestination, caseDestinations);
462ace01605SRiver Riddle }
463ace01605SRiver Riddle 
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<APInt> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)464ace01605SRiver Riddle void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
465ace01605SRiver Riddle                      Block *defaultDestination, ValueRange defaultOperands,
466ace01605SRiver Riddle                      ArrayRef<APInt> caseValues, BlockRange caseDestinations,
467ace01605SRiver Riddle                      ArrayRef<ValueRange> caseOperands) {
468ace01605SRiver Riddle   DenseIntElementsAttr caseValuesAttr;
469ace01605SRiver Riddle   if (!caseValues.empty()) {
470ace01605SRiver Riddle     ShapedType caseValueType = VectorType::get(
471ace01605SRiver Riddle         static_cast<int64_t>(caseValues.size()), value.getType());
472ace01605SRiver Riddle     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
473ace01605SRiver Riddle   }
474ace01605SRiver Riddle   build(builder, result, value, defaultDestination, defaultOperands,
475ace01605SRiver Riddle         caseValuesAttr, caseDestinations, caseOperands);
476ace01605SRiver Riddle }
477ace01605SRiver Riddle 
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<int32_t> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)478b34fb277SAlexander Batashev void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
479b34fb277SAlexander Batashev                      Block *defaultDestination, ValueRange defaultOperands,
480b34fb277SAlexander Batashev                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
481b34fb277SAlexander Batashev                      ArrayRef<ValueRange> caseOperands) {
482b34fb277SAlexander Batashev   DenseIntElementsAttr caseValuesAttr;
483b34fb277SAlexander Batashev   if (!caseValues.empty()) {
484b34fb277SAlexander Batashev     ShapedType caseValueType = VectorType::get(
485b34fb277SAlexander Batashev         static_cast<int64_t>(caseValues.size()), value.getType());
486b34fb277SAlexander Batashev     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
487b34fb277SAlexander Batashev   }
488b34fb277SAlexander Batashev   build(builder, result, value, defaultDestination, defaultOperands,
489b34fb277SAlexander Batashev         caseValuesAttr, caseDestinations, caseOperands);
490b34fb277SAlexander Batashev }
491b34fb277SAlexander Batashev 
492ace01605SRiver Riddle /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
493ace01605SRiver Riddle ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
parseSwitchOpCases(OpAsmParser & parser,Type & flagType,Block * & defaultDestination,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & defaultOperands,SmallVectorImpl<Type> & defaultOperandTypes,DenseIntElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> & caseOperands,SmallVectorImpl<SmallVector<Type>> & caseOperandTypes)494ace01605SRiver Riddle static ParseResult parseSwitchOpCases(
495ace01605SRiver Riddle     OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
496e13d23bcSMarkus Böck     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
497ace01605SRiver Riddle     SmallVectorImpl<Type> &defaultOperandTypes,
498ace01605SRiver Riddle     DenseIntElementsAttr &caseValues,
499ace01605SRiver Riddle     SmallVectorImpl<Block *> &caseDestinations,
500e13d23bcSMarkus Böck     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
501ace01605SRiver Riddle     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
502ace01605SRiver Riddle   if (parser.parseKeyword("default") || parser.parseColon() ||
503ace01605SRiver Riddle       parser.parseSuccessor(defaultDestination))
504ace01605SRiver Riddle     return failure();
505ace01605SRiver Riddle   if (succeeded(parser.parseOptionalLParen())) {
5065dedf911SChris Lattner     if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
5075dedf911SChris Lattner                                 /*allowResultNumber=*/false) ||
508ace01605SRiver Riddle         parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
509ace01605SRiver Riddle       return failure();
510ace01605SRiver Riddle   }
511ace01605SRiver Riddle 
512ace01605SRiver Riddle   SmallVector<APInt> values;
513ace01605SRiver Riddle   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
514ace01605SRiver Riddle   while (succeeded(parser.parseOptionalComma())) {
515ace01605SRiver Riddle     int64_t value = 0;
516ace01605SRiver Riddle     if (failed(parser.parseInteger(value)))
517ace01605SRiver Riddle       return failure();
518ace01605SRiver Riddle     values.push_back(APInt(bitWidth, value));
519ace01605SRiver Riddle 
520ace01605SRiver Riddle     Block *destination;
521e13d23bcSMarkus Böck     SmallVector<OpAsmParser::UnresolvedOperand> operands;
522ace01605SRiver Riddle     SmallVector<Type> operandTypes;
523ace01605SRiver Riddle     if (failed(parser.parseColon()) ||
524ace01605SRiver Riddle         failed(parser.parseSuccessor(destination)))
525ace01605SRiver Riddle       return failure();
526ace01605SRiver Riddle     if (succeeded(parser.parseOptionalLParen())) {
5275dedf911SChris Lattner       if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
5285dedf911SChris Lattner                                          /*allowResultNumber=*/false)) ||
529ace01605SRiver Riddle           failed(parser.parseColonTypeList(operandTypes)) ||
530ace01605SRiver Riddle           failed(parser.parseRParen()))
531ace01605SRiver Riddle         return failure();
532ace01605SRiver Riddle     }
533ace01605SRiver Riddle     caseDestinations.push_back(destination);
534ace01605SRiver Riddle     caseOperands.emplace_back(operands);
535ace01605SRiver Riddle     caseOperandTypes.emplace_back(operandTypes);
536ace01605SRiver Riddle   }
537ace01605SRiver Riddle 
538ace01605SRiver Riddle   if (!values.empty()) {
539ace01605SRiver Riddle     ShapedType caseValueType =
540ace01605SRiver Riddle         VectorType::get(static_cast<int64_t>(values.size()), flagType);
541ace01605SRiver Riddle     caseValues = DenseIntElementsAttr::get(caseValueType, values);
542ace01605SRiver Riddle   }
543ace01605SRiver Riddle   return success();
544ace01605SRiver Riddle }
545ace01605SRiver Riddle 
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,Type flagType,Block * defaultDestination,OperandRange defaultOperands,TypeRange defaultOperandTypes,DenseIntElementsAttr caseValues,SuccessorRange caseDestinations,OperandRangeRange caseOperands,const TypeRangeRange & caseOperandTypes)546ace01605SRiver Riddle static void printSwitchOpCases(
547ace01605SRiver Riddle     OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
548ace01605SRiver Riddle     OperandRange defaultOperands, TypeRange defaultOperandTypes,
549ace01605SRiver Riddle     DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
550ace01605SRiver Riddle     OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
551ace01605SRiver Riddle   p << "  default: ";
552ace01605SRiver Riddle   p.printSuccessorAndUseList(defaultDestination, defaultOperands);
553ace01605SRiver Riddle 
554ace01605SRiver Riddle   if (!caseValues)
555ace01605SRiver Riddle     return;
556ace01605SRiver Riddle 
557ace01605SRiver Riddle   for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
558ace01605SRiver Riddle     p << ',';
559ace01605SRiver Riddle     p.printNewline();
560ace01605SRiver Riddle     p << "  ";
561ace01605SRiver Riddle     p << it.value().getLimitedValue();
562ace01605SRiver Riddle     p << ": ";
563ace01605SRiver Riddle     p.printSuccessorAndUseList(caseDestinations[it.index()],
564ace01605SRiver Riddle                                caseOperands[it.index()]);
565ace01605SRiver Riddle   }
566ace01605SRiver Riddle   p.printNewline();
567ace01605SRiver Riddle }
568ace01605SRiver Riddle 
verify()569ace01605SRiver Riddle LogicalResult SwitchOp::verify() {
570ace01605SRiver Riddle   auto caseValues = getCaseValues();
571ace01605SRiver Riddle   auto caseDestinations = getCaseDestinations();
572ace01605SRiver Riddle 
573ace01605SRiver Riddle   if (!caseValues && caseDestinations.empty())
574ace01605SRiver Riddle     return success();
575ace01605SRiver Riddle 
576ace01605SRiver Riddle   Type flagType = getFlag().getType();
577ace01605SRiver Riddle   Type caseValueType = caseValues->getType().getElementType();
578ace01605SRiver Riddle   if (caseValueType != flagType)
579ace01605SRiver Riddle     return emitOpError() << "'flag' type (" << flagType
580ace01605SRiver Riddle                          << ") should match case value type (" << caseValueType
581ace01605SRiver Riddle                          << ")";
582ace01605SRiver Riddle 
583ace01605SRiver Riddle   if (caseValues &&
584ace01605SRiver Riddle       caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
585ace01605SRiver Riddle     return emitOpError() << "number of case values (" << caseValues->size()
586ace01605SRiver Riddle                          << ") should match number of "
587ace01605SRiver Riddle                             "case destinations ("
588ace01605SRiver Riddle                          << caseDestinations.size() << ")";
589ace01605SRiver Riddle   return success();
590ace01605SRiver Riddle }
591ace01605SRiver Riddle 
getSuccessorOperands(unsigned index)5920c789db5SMarkus Böck SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
593ace01605SRiver Riddle   assert(index < getNumSuccessors() && "invalid successor index");
5940c789db5SMarkus Böck   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
5950c789db5SMarkus Böck                                       : getCaseOperandsMutable(index - 1));
596ace01605SRiver Riddle }
597ace01605SRiver Riddle 
getSuccessorForOperands(ArrayRef<Attribute> operands)598ace01605SRiver Riddle Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
599ace01605SRiver Riddle   Optional<DenseIntElementsAttr> caseValues = getCaseValues();
600ace01605SRiver Riddle 
601ace01605SRiver Riddle   if (!caseValues)
602ace01605SRiver Riddle     return getDefaultDestination();
603ace01605SRiver Riddle 
604ace01605SRiver Riddle   SuccessorRange caseDests = getCaseDestinations();
605ace01605SRiver Riddle   if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
606ace01605SRiver Riddle     for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
607ace01605SRiver Riddle       if (it.value() == value.getValue())
608ace01605SRiver Riddle         return caseDests[it.index()];
609ace01605SRiver Riddle     return getDefaultDestination();
610ace01605SRiver Riddle   }
611ace01605SRiver Riddle   return nullptr;
612ace01605SRiver Riddle }
613ace01605SRiver Riddle 
614ace01605SRiver Riddle /// switch %flag : i32, [
615ace01605SRiver Riddle ///   default:  ^bb1
616ace01605SRiver Riddle /// ]
617ace01605SRiver Riddle ///  -> br ^bb1
simplifySwitchWithOnlyDefault(SwitchOp op,PatternRewriter & rewriter)618ace01605SRiver Riddle static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
619ace01605SRiver Riddle                                                    PatternRewriter &rewriter) {
620ace01605SRiver Riddle   if (!op.getCaseDestinations().empty())
621ace01605SRiver Riddle     return failure();
622ace01605SRiver Riddle 
623ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
624ace01605SRiver Riddle                                         op.getDefaultOperands());
625ace01605SRiver Riddle   return success();
626ace01605SRiver Riddle }
627ace01605SRiver Riddle 
628ace01605SRiver Riddle /// switch %flag : i32, [
629ace01605SRiver Riddle ///   default: ^bb1,
630ace01605SRiver Riddle ///   42: ^bb1,
631ace01605SRiver Riddle ///   43: ^bb2
632ace01605SRiver Riddle /// ]
633ace01605SRiver Riddle /// ->
634ace01605SRiver Riddle /// switch %flag : i32, [
635ace01605SRiver Riddle ///   default: ^bb1,
636ace01605SRiver Riddle ///   43: ^bb2
637ace01605SRiver Riddle /// ]
638ace01605SRiver Riddle static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op,PatternRewriter & rewriter)639ace01605SRiver Riddle dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
640ace01605SRiver Riddle   SmallVector<Block *> newCaseDestinations;
641ace01605SRiver Riddle   SmallVector<ValueRange> newCaseOperands;
642ace01605SRiver Riddle   SmallVector<APInt> newCaseValues;
643ace01605SRiver Riddle   bool requiresChange = false;
644ace01605SRiver Riddle   auto caseValues = op.getCaseValues();
645ace01605SRiver Riddle   auto caseDests = op.getCaseDestinations();
646ace01605SRiver Riddle 
647ace01605SRiver Riddle   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
648ace01605SRiver Riddle     if (caseDests[it.index()] == op.getDefaultDestination() &&
649ace01605SRiver Riddle         op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
650ace01605SRiver Riddle       requiresChange = true;
651ace01605SRiver Riddle       continue;
652ace01605SRiver Riddle     }
653ace01605SRiver Riddle     newCaseDestinations.push_back(caseDests[it.index()]);
654ace01605SRiver Riddle     newCaseOperands.push_back(op.getCaseOperands(it.index()));
655ace01605SRiver Riddle     newCaseValues.push_back(it.value());
656ace01605SRiver Riddle   }
657ace01605SRiver Riddle 
658ace01605SRiver Riddle   if (!requiresChange)
659ace01605SRiver Riddle     return failure();
660ace01605SRiver Riddle 
661ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<SwitchOp>(
662ace01605SRiver Riddle       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
663ace01605SRiver Riddle       newCaseValues, newCaseDestinations, newCaseOperands);
664ace01605SRiver Riddle   return success();
665ace01605SRiver Riddle }
666ace01605SRiver Riddle 
667ace01605SRiver Riddle /// Helper for folding a switch with a constant value.
668ace01605SRiver Riddle /// switch %c_42 : i32, [
669ace01605SRiver Riddle ///   default: ^bb1 ,
670ace01605SRiver Riddle ///   42: ^bb2,
671ace01605SRiver Riddle ///   43: ^bb3
672ace01605SRiver Riddle /// ]
673ace01605SRiver Riddle /// -> br ^bb2
foldSwitch(SwitchOp op,PatternRewriter & rewriter,const APInt & caseValue)674ace01605SRiver Riddle static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
675ace01605SRiver Riddle                        const APInt &caseValue) {
676ace01605SRiver Riddle   auto caseValues = op.getCaseValues();
677ace01605SRiver Riddle   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
678ace01605SRiver Riddle     if (it.value() == caseValue) {
679ace01605SRiver Riddle       rewriter.replaceOpWithNewOp<BranchOp>(
680ace01605SRiver Riddle           op, op.getCaseDestinations()[it.index()],
681ace01605SRiver Riddle           op.getCaseOperands(it.index()));
682ace01605SRiver Riddle       return;
683ace01605SRiver Riddle     }
684ace01605SRiver Riddle   }
685ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
686ace01605SRiver Riddle                                         op.getDefaultOperands());
687ace01605SRiver Riddle }
688ace01605SRiver Riddle 
689ace01605SRiver Riddle /// switch %c_42 : i32, [
690ace01605SRiver Riddle ///   default: ^bb1,
691ace01605SRiver Riddle ///   42: ^bb2,
692ace01605SRiver Riddle ///   43: ^bb3
693ace01605SRiver Riddle /// ]
694ace01605SRiver Riddle /// -> br ^bb2
simplifyConstSwitchValue(SwitchOp op,PatternRewriter & rewriter)695ace01605SRiver Riddle static LogicalResult simplifyConstSwitchValue(SwitchOp op,
696ace01605SRiver Riddle                                               PatternRewriter &rewriter) {
697ace01605SRiver Riddle   APInt caseValue;
698ace01605SRiver Riddle   if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
699ace01605SRiver Riddle     return failure();
700ace01605SRiver Riddle 
701ace01605SRiver Riddle   foldSwitch(op, rewriter, caseValue);
702ace01605SRiver Riddle   return success();
703ace01605SRiver Riddle }
704ace01605SRiver Riddle 
705ace01605SRiver Riddle /// switch %c_42 : i32, [
706ace01605SRiver Riddle ///   default: ^bb1,
707ace01605SRiver Riddle ///   42: ^bb2,
708ace01605SRiver Riddle /// ]
709ace01605SRiver Riddle /// ^bb2:
710ace01605SRiver Riddle ///   br ^bb3
711ace01605SRiver Riddle /// ->
712ace01605SRiver Riddle /// switch %c_42 : i32, [
713ace01605SRiver Riddle ///   default: ^bb1,
714ace01605SRiver Riddle ///   42: ^bb3,
715ace01605SRiver Riddle /// ]
simplifyPassThroughSwitch(SwitchOp op,PatternRewriter & rewriter)716ace01605SRiver Riddle static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
717ace01605SRiver Riddle                                                PatternRewriter &rewriter) {
718ace01605SRiver Riddle   SmallVector<Block *> newCaseDests;
719ace01605SRiver Riddle   SmallVector<ValueRange> newCaseOperands;
720ace01605SRiver Riddle   SmallVector<SmallVector<Value>> argStorage;
721ace01605SRiver Riddle   auto caseValues = op.getCaseValues();
722f735b3a2SVitaly Buka   argStorage.reserve(caseValues->size() + 1);
723ace01605SRiver Riddle   auto caseDests = op.getCaseDestinations();
724ace01605SRiver Riddle   bool requiresChange = false;
725ace01605SRiver Riddle   for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
726ace01605SRiver Riddle     Block *caseDest = caseDests[i];
727ace01605SRiver Riddle     ValueRange caseOperands = op.getCaseOperands(i);
728ace01605SRiver Riddle     argStorage.emplace_back();
729ace01605SRiver Riddle     if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
730ace01605SRiver Riddle       requiresChange = true;
731ace01605SRiver Riddle 
732ace01605SRiver Riddle     newCaseDests.push_back(caseDest);
733ace01605SRiver Riddle     newCaseOperands.push_back(caseOperands);
734ace01605SRiver Riddle   }
735ace01605SRiver Riddle 
736ace01605SRiver Riddle   Block *defaultDest = op.getDefaultDestination();
737ace01605SRiver Riddle   ValueRange defaultOperands = op.getDefaultOperands();
738ace01605SRiver Riddle   argStorage.emplace_back();
739ace01605SRiver Riddle 
740ace01605SRiver Riddle   if (succeeded(
741ace01605SRiver Riddle           collapseBranch(defaultDest, defaultOperands, argStorage.back())))
742ace01605SRiver Riddle     requiresChange = true;
743ace01605SRiver Riddle 
744ace01605SRiver Riddle   if (!requiresChange)
745ace01605SRiver Riddle     return failure();
746ace01605SRiver Riddle 
747ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
748*6d5fc1e3SKazu Hirata                                         defaultOperands, *caseValues,
749ace01605SRiver Riddle                                         newCaseDests, newCaseOperands);
750ace01605SRiver Riddle   return success();
751ace01605SRiver Riddle }
752ace01605SRiver Riddle 
753ace01605SRiver Riddle /// switch %flag : i32, [
754ace01605SRiver Riddle ///   default: ^bb1,
755ace01605SRiver Riddle ///   42: ^bb2,
756ace01605SRiver Riddle /// ]
757ace01605SRiver Riddle /// ^bb2:
758ace01605SRiver Riddle ///   switch %flag : i32, [
759ace01605SRiver Riddle ///     default: ^bb3,
760ace01605SRiver Riddle ///     42: ^bb4
761ace01605SRiver Riddle ///   ]
762ace01605SRiver Riddle /// ->
763ace01605SRiver Riddle /// switch %flag : i32, [
764ace01605SRiver Riddle ///   default: ^bb1,
765ace01605SRiver Riddle ///   42: ^bb2,
766ace01605SRiver Riddle /// ]
767ace01605SRiver Riddle /// ^bb2:
768ace01605SRiver Riddle ///   br ^bb4
769ace01605SRiver Riddle ///
770ace01605SRiver Riddle ///  and
771ace01605SRiver Riddle ///
772ace01605SRiver Riddle /// switch %flag : i32, [
773ace01605SRiver Riddle ///   default: ^bb1,
774ace01605SRiver Riddle ///   42: ^bb2,
775ace01605SRiver Riddle /// ]
776ace01605SRiver Riddle /// ^bb2:
777ace01605SRiver Riddle ///   switch %flag : i32, [
778ace01605SRiver Riddle ///     default: ^bb3,
779ace01605SRiver Riddle ///     43: ^bb4
780ace01605SRiver Riddle ///   ]
781ace01605SRiver Riddle /// ->
782ace01605SRiver Riddle /// switch %flag : i32, [
783ace01605SRiver Riddle ///   default: ^bb1,
784ace01605SRiver Riddle ///   42: ^bb2,
785ace01605SRiver Riddle /// ]
786ace01605SRiver Riddle /// ^bb2:
787ace01605SRiver Riddle ///   br ^bb3
788ace01605SRiver Riddle static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)789ace01605SRiver Riddle simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
790ace01605SRiver Riddle                                         PatternRewriter &rewriter) {
791ace01605SRiver Riddle   // Check that we have a single distinct predecessor.
792ace01605SRiver Riddle   Block *currentBlock = op->getBlock();
793ace01605SRiver Riddle   Block *predecessor = currentBlock->getSinglePredecessor();
794ace01605SRiver Riddle   if (!predecessor)
795ace01605SRiver Riddle     return failure();
796ace01605SRiver Riddle 
797ace01605SRiver Riddle   // Check that the predecessor terminates with a switch branch to this block
798ace01605SRiver Riddle   // and that it branches on the same condition and that this branch isn't the
799ace01605SRiver Riddle   // default destination.
800ace01605SRiver Riddle   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
801ace01605SRiver Riddle   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
802ace01605SRiver Riddle       predSwitch.getDefaultDestination() == currentBlock)
803ace01605SRiver Riddle     return failure();
804ace01605SRiver Riddle 
805ace01605SRiver Riddle   // Fold this switch to an unconditional branch.
806ace01605SRiver Riddle   SuccessorRange predDests = predSwitch.getCaseDestinations();
807ace01605SRiver Riddle   auto it = llvm::find(predDests, currentBlock);
808ace01605SRiver Riddle   if (it != predDests.end()) {
809ace01605SRiver Riddle     Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
810ace01605SRiver Riddle     foldSwitch(op, rewriter,
811ace01605SRiver Riddle                predCaseValues->getValues<APInt>()[it - predDests.begin()]);
812ace01605SRiver Riddle   } else {
813ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
814ace01605SRiver Riddle                                           op.getDefaultOperands());
815ace01605SRiver Riddle   }
816ace01605SRiver Riddle   return success();
817ace01605SRiver Riddle }
818ace01605SRiver Riddle 
819ace01605SRiver Riddle /// switch %flag : i32, [
820ace01605SRiver Riddle ///   default: ^bb1,
821ace01605SRiver Riddle ///   42: ^bb2
822ace01605SRiver Riddle /// ]
823ace01605SRiver Riddle /// ^bb1:
824ace01605SRiver Riddle ///   switch %flag : i32, [
825ace01605SRiver Riddle ///     default: ^bb3,
826ace01605SRiver Riddle ///     42: ^bb4,
827ace01605SRiver Riddle ///     43: ^bb5
828ace01605SRiver Riddle ///   ]
829ace01605SRiver Riddle /// ->
830ace01605SRiver Riddle /// switch %flag : i32, [
831ace01605SRiver Riddle ///   default: ^bb1,
832ace01605SRiver Riddle ///   42: ^bb2,
833ace01605SRiver Riddle /// ]
834ace01605SRiver Riddle /// ^bb1:
835ace01605SRiver Riddle ///   switch %flag : i32, [
836ace01605SRiver Riddle ///     default: ^bb3,
837ace01605SRiver Riddle ///     43: ^bb5
838ace01605SRiver Riddle ///   ]
839ace01605SRiver Riddle static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)840ace01605SRiver Riddle simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
841ace01605SRiver Riddle                                                PatternRewriter &rewriter) {
842ace01605SRiver Riddle   // Check that we have a single distinct predecessor.
843ace01605SRiver Riddle   Block *currentBlock = op->getBlock();
844ace01605SRiver Riddle   Block *predecessor = currentBlock->getSinglePredecessor();
845ace01605SRiver Riddle   if (!predecessor)
846ace01605SRiver Riddle     return failure();
847ace01605SRiver Riddle 
848ace01605SRiver Riddle   // Check that the predecessor terminates with a switch branch to this block
849ace01605SRiver Riddle   // and that it branches on the same condition and that this branch is the
850ace01605SRiver Riddle   // default destination.
851ace01605SRiver Riddle   auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
852ace01605SRiver Riddle   if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
853ace01605SRiver Riddle       predSwitch.getDefaultDestination() != currentBlock)
854ace01605SRiver Riddle     return failure();
855ace01605SRiver Riddle 
856ace01605SRiver Riddle   // Delete case values that are not possible here.
857ace01605SRiver Riddle   DenseSet<APInt> caseValuesToRemove;
858ace01605SRiver Riddle   auto predDests = predSwitch.getCaseDestinations();
859ace01605SRiver Riddle   auto predCaseValues = predSwitch.getCaseValues();
860ace01605SRiver Riddle   for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
861ace01605SRiver Riddle     if (currentBlock != predDests[i])
862ace01605SRiver Riddle       caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
863ace01605SRiver Riddle 
864ace01605SRiver Riddle   SmallVector<Block *> newCaseDestinations;
865ace01605SRiver Riddle   SmallVector<ValueRange> newCaseOperands;
866ace01605SRiver Riddle   SmallVector<APInt> newCaseValues;
867ace01605SRiver Riddle   bool requiresChange = false;
868ace01605SRiver Riddle 
869ace01605SRiver Riddle   auto caseValues = op.getCaseValues();
870ace01605SRiver Riddle   auto caseDests = op.getCaseDestinations();
871ace01605SRiver Riddle   for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
872ace01605SRiver Riddle     if (caseValuesToRemove.contains(it.value())) {
873ace01605SRiver Riddle       requiresChange = true;
874ace01605SRiver Riddle       continue;
875ace01605SRiver Riddle     }
876ace01605SRiver Riddle     newCaseDestinations.push_back(caseDests[it.index()]);
877ace01605SRiver Riddle     newCaseOperands.push_back(op.getCaseOperands(it.index()));
878ace01605SRiver Riddle     newCaseValues.push_back(it.value());
879ace01605SRiver Riddle   }
880ace01605SRiver Riddle 
881ace01605SRiver Riddle   if (!requiresChange)
882ace01605SRiver Riddle     return failure();
883ace01605SRiver Riddle 
884ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<SwitchOp>(
885ace01605SRiver Riddle       op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
886ace01605SRiver Riddle       newCaseValues, newCaseDestinations, newCaseOperands);
887ace01605SRiver Riddle   return success();
888ace01605SRiver Riddle }
889ace01605SRiver Riddle 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)890ace01605SRiver Riddle void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
891ace01605SRiver Riddle                                            MLIRContext *context) {
892ace01605SRiver Riddle   results.add(&simplifySwitchWithOnlyDefault)
893ace01605SRiver Riddle       .add(&dropSwitchCasesThatMatchDefault)
894ace01605SRiver Riddle       .add(&simplifyConstSwitchValue)
895ace01605SRiver Riddle       .add(&simplifyPassThroughSwitch)
896ace01605SRiver Riddle       .add(&simplifySwitchFromSwitchOnSameCondition)
897ace01605SRiver Riddle       .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
898ace01605SRiver Riddle }
899ace01605SRiver Riddle 
900ace01605SRiver Riddle //===----------------------------------------------------------------------===//
901ace01605SRiver Riddle // TableGen'd op method definitions
902ace01605SRiver Riddle //===----------------------------------------------------------------------===//
903ace01605SRiver Riddle 
904ace01605SRiver Riddle #define GET_OP_CLASSES
905ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
906