1ace01605SRiver Riddle //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2ace01605SRiver Riddle //
3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ace01605SRiver Riddle //
7ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8ace01605SRiver Riddle
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
32ace01605SRiver Riddle
33ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34ace01605SRiver Riddle
35ace01605SRiver Riddle using namespace mlir;
36ace01605SRiver Riddle using namespace mlir::cf;
37ace01605SRiver Riddle
38ace01605SRiver Riddle //===----------------------------------------------------------------------===//
39ace01605SRiver Riddle // ControlFlowDialect Interfaces
40ace01605SRiver Riddle //===----------------------------------------------------------------------===//
41ace01605SRiver Riddle namespace {
42ace01605SRiver Riddle /// This class defines the interface for handling inlining with control flow
43ace01605SRiver Riddle /// operations.
44ace01605SRiver Riddle struct ControlFlowInlinerInterface : public DialectInlinerInterface {
45ace01605SRiver Riddle using DialectInlinerInterface::DialectInlinerInterface;
46ace01605SRiver Riddle ~ControlFlowInlinerInterface() override = default;
47ace01605SRiver Riddle
48ace01605SRiver Riddle /// All control flow operations can be inlined.
isLegalToInline__anonbb8b539b0111::ControlFlowInlinerInterface49ace01605SRiver Riddle bool isLegalToInline(Operation *call, Operation *callable,
50ace01605SRiver Riddle bool wouldBeCloned) const final {
51ace01605SRiver Riddle return true;
52ace01605SRiver Riddle }
isLegalToInline__anonbb8b539b0111::ControlFlowInlinerInterface53ace01605SRiver Riddle bool isLegalToInline(Operation *, Region *, bool,
54ace01605SRiver Riddle BlockAndValueMapping &) const final {
55ace01605SRiver Riddle return true;
56ace01605SRiver Riddle }
57ace01605SRiver Riddle
58ace01605SRiver Riddle /// ControlFlow terminator operations don't really need any special handing.
handleTerminator__anonbb8b539b0111::ControlFlowInlinerInterface59ace01605SRiver Riddle void handleTerminator(Operation *op, Block *newDest) const final {}
60ace01605SRiver Riddle };
61ace01605SRiver Riddle } // namespace
62ace01605SRiver Riddle
63ace01605SRiver Riddle //===----------------------------------------------------------------------===//
64ace01605SRiver Riddle // ControlFlowDialect
65ace01605SRiver Riddle //===----------------------------------------------------------------------===//
66ace01605SRiver Riddle
/// Dialect constructor hook: registers every operation listed in the
/// generated `ControlFlowOps.cpp.inc` (expanded under `GET_OP_LIST`) and
/// attaches the dialect's inliner interface.
void ControlFlowDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  addInterfaces<ControlFlowInlinerInterface>();
}
74ace01605SRiver Riddle
75ace01605SRiver Riddle //===----------------------------------------------------------------------===//
76ace01605SRiver Riddle // AssertOp
77ace01605SRiver Riddle //===----------------------------------------------------------------------===//
78ace01605SRiver Riddle
canonicalize(AssertOp op,PatternRewriter & rewriter)79ace01605SRiver Riddle LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
80ace01605SRiver Riddle // Erase assertion if argument is constant true.
81ace01605SRiver Riddle if (matchPattern(op.getArg(), m_One())) {
82ace01605SRiver Riddle rewriter.eraseOp(op);
83ace01605SRiver Riddle return success();
84ace01605SRiver Riddle }
85ace01605SRiver Riddle return failure();
86ace01605SRiver Riddle }
87ace01605SRiver Riddle
88ace01605SRiver Riddle //===----------------------------------------------------------------------===//
89ace01605SRiver Riddle // BranchOp
90ace01605SRiver Riddle //===----------------------------------------------------------------------===//
91ace01605SRiver Riddle
92ace01605SRiver Riddle /// Given a successor, try to collapse it to a new destination if it only
93ace01605SRiver Riddle /// contains a passthrough unconditional branch. If the successor is
94ace01605SRiver Riddle /// collapsable, `successor` and `successorOperands` are updated to reference
95ace01605SRiver Riddle /// the new destination and values. `argStorage` is used as storage if operands
96ace01605SRiver Riddle /// to the collapsed successor need to be remapped. It must outlive uses of
97ace01605SRiver Riddle /// successorOperands.
collapseBranch(Block * & successor,ValueRange & successorOperands,SmallVectorImpl<Value> & argStorage)98ace01605SRiver Riddle static LogicalResult collapseBranch(Block *&successor,
99ace01605SRiver Riddle ValueRange &successorOperands,
100ace01605SRiver Riddle SmallVectorImpl<Value> &argStorage) {
101ace01605SRiver Riddle // Check that the successor only contains a unconditional branch.
102ace01605SRiver Riddle if (std::next(successor->begin()) != successor->end())
103ace01605SRiver Riddle return failure();
104ace01605SRiver Riddle // Check that the terminator is an unconditional branch.
105ace01605SRiver Riddle BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
106ace01605SRiver Riddle if (!successorBranch)
107ace01605SRiver Riddle return failure();
108ace01605SRiver Riddle // Check that the arguments are only used within the terminator.
109ace01605SRiver Riddle for (BlockArgument arg : successor->getArguments()) {
110ace01605SRiver Riddle for (Operation *user : arg.getUsers())
111ace01605SRiver Riddle if (user != successorBranch)
112ace01605SRiver Riddle return failure();
113ace01605SRiver Riddle }
114ace01605SRiver Riddle // Don't try to collapse branches to infinite loops.
115ace01605SRiver Riddle Block *successorDest = successorBranch.getDest();
116ace01605SRiver Riddle if (successorDest == successor)
117ace01605SRiver Riddle return failure();
118ace01605SRiver Riddle
119ace01605SRiver Riddle // Update the operands to the successor. If the branch parent has no
120ace01605SRiver Riddle // arguments, we can use the branch operands directly.
121ace01605SRiver Riddle OperandRange operands = successorBranch.getOperands();
122ace01605SRiver Riddle if (successor->args_empty()) {
123ace01605SRiver Riddle successor = successorDest;
124ace01605SRiver Riddle successorOperands = operands;
125ace01605SRiver Riddle return success();
126ace01605SRiver Riddle }
127ace01605SRiver Riddle
128ace01605SRiver Riddle // Otherwise, we need to remap any argument operands.
129ace01605SRiver Riddle for (Value operand : operands) {
130ace01605SRiver Riddle BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
131ace01605SRiver Riddle if (argOperand && argOperand.getOwner() == successor)
132ace01605SRiver Riddle argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
133ace01605SRiver Riddle else
134ace01605SRiver Riddle argStorage.push_back(operand);
135ace01605SRiver Riddle }
136ace01605SRiver Riddle successor = successorDest;
137ace01605SRiver Riddle successorOperands = argStorage;
138ace01605SRiver Riddle return success();
139ace01605SRiver Riddle }
140ace01605SRiver Riddle
141ace01605SRiver Riddle /// Simplify a branch to a block that has a single predecessor. This effectively
142ace01605SRiver Riddle /// merges the two blocks.
143ace01605SRiver Riddle static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op,PatternRewriter & rewriter)144ace01605SRiver Riddle simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
145ace01605SRiver Riddle // Check that the successor block has a single predecessor.
146ace01605SRiver Riddle Block *succ = op.getDest();
147ace01605SRiver Riddle Block *opParent = op->getBlock();
148ace01605SRiver Riddle if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
149ace01605SRiver Riddle return failure();
150ace01605SRiver Riddle
151ace01605SRiver Riddle // Merge the successor into the current block and erase the branch.
152ace01605SRiver Riddle rewriter.mergeBlocks(succ, opParent, op.getOperands());
153ace01605SRiver Riddle rewriter.eraseOp(op);
154ace01605SRiver Riddle return success();
155ace01605SRiver Riddle }
156ace01605SRiver Riddle
157ace01605SRiver Riddle /// br ^bb1
158ace01605SRiver Riddle /// ^bb1
159ace01605SRiver Riddle /// br ^bbN(...)
160ace01605SRiver Riddle ///
161ace01605SRiver Riddle /// -> br ^bbN(...)
162ace01605SRiver Riddle ///
simplifyPassThroughBr(BranchOp op,PatternRewriter & rewriter)163ace01605SRiver Riddle static LogicalResult simplifyPassThroughBr(BranchOp op,
164ace01605SRiver Riddle PatternRewriter &rewriter) {
165ace01605SRiver Riddle Block *dest = op.getDest();
166ace01605SRiver Riddle ValueRange destOperands = op.getOperands();
167ace01605SRiver Riddle SmallVector<Value, 4> destOperandStorage;
168ace01605SRiver Riddle
169ace01605SRiver Riddle // Try to collapse the successor if it points somewhere other than this
170ace01605SRiver Riddle // block.
171ace01605SRiver Riddle if (dest == op->getBlock() ||
172ace01605SRiver Riddle failed(collapseBranch(dest, destOperands, destOperandStorage)))
173ace01605SRiver Riddle return failure();
174ace01605SRiver Riddle
175ace01605SRiver Riddle // Create a new branch with the collapsed successor.
176ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
177ace01605SRiver Riddle return success();
178ace01605SRiver Riddle }
179ace01605SRiver Riddle
canonicalize(BranchOp op,PatternRewriter & rewriter)180ace01605SRiver Riddle LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
181ace01605SRiver Riddle return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
182ace01605SRiver Riddle succeeded(simplifyPassThroughBr(op, rewriter)));
183ace01605SRiver Riddle }
184ace01605SRiver Riddle
setDest(Block * block)185ace01605SRiver Riddle void BranchOp::setDest(Block *block) { return setSuccessor(block); }
186ace01605SRiver Riddle
eraseOperand(unsigned index)187ace01605SRiver Riddle void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
188ace01605SRiver Riddle
getSuccessorOperands(unsigned index)1890c789db5SMarkus Böck SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
190ace01605SRiver Riddle assert(index == 0 && "invalid successor index");
1910c789db5SMarkus Böck return SuccessorOperands(getDestOperandsMutable());
192ace01605SRiver Riddle }
193ace01605SRiver Riddle
getSuccessorForOperands(ArrayRef<Attribute>)194ace01605SRiver Riddle Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
195ace01605SRiver Riddle return getDest();
196ace01605SRiver Riddle }
197ace01605SRiver Riddle
198ace01605SRiver Riddle //===----------------------------------------------------------------------===//
199ace01605SRiver Riddle // CondBranchOp
200ace01605SRiver Riddle //===----------------------------------------------------------------------===//
201ace01605SRiver Riddle
202ace01605SRiver Riddle namespace {
203ace01605SRiver Riddle /// cf.cond_br true, ^bb1, ^bb2
204ace01605SRiver Riddle /// -> br ^bb1
205ace01605SRiver Riddle /// cf.cond_br false, ^bb1, ^bb2
206ace01605SRiver Riddle /// -> br ^bb2
207ace01605SRiver Riddle ///
208ace01605SRiver Riddle struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
209ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern;
210ace01605SRiver Riddle
matchAndRewrite__anonbb8b539b0211::SimplifyConstCondBranchPred211ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr,
212ace01605SRiver Riddle PatternRewriter &rewriter) const override {
213ace01605SRiver Riddle if (matchPattern(condbr.getCondition(), m_NonZero())) {
214ace01605SRiver Riddle // True branch taken.
215ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
216ace01605SRiver Riddle condbr.getTrueOperands());
217ace01605SRiver Riddle return success();
218ace01605SRiver Riddle }
219ace01605SRiver Riddle if (matchPattern(condbr.getCondition(), m_Zero())) {
220ace01605SRiver Riddle // False branch taken.
221ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
222ace01605SRiver Riddle condbr.getFalseOperands());
223ace01605SRiver Riddle return success();
224ace01605SRiver Riddle }
225ace01605SRiver Riddle return failure();
226ace01605SRiver Riddle }
227ace01605SRiver Riddle };
228ace01605SRiver Riddle
229ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1, ^bb2
230ace01605SRiver Riddle /// ^bb1
231ace01605SRiver Riddle /// br ^bbN(...)
232ace01605SRiver Riddle /// ^bb2
233ace01605SRiver Riddle /// br ^bbK(...)
234ace01605SRiver Riddle ///
235ace01605SRiver Riddle /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
236ace01605SRiver Riddle ///
237ace01605SRiver Riddle struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
238ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern;
239ace01605SRiver Riddle
matchAndRewrite__anonbb8b539b0211::SimplifyPassThroughCondBranch240ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr,
241ace01605SRiver Riddle PatternRewriter &rewriter) const override {
242ace01605SRiver Riddle Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
243ace01605SRiver Riddle ValueRange trueDestOperands = condbr.getTrueOperands();
244ace01605SRiver Riddle ValueRange falseDestOperands = condbr.getFalseOperands();
245ace01605SRiver Riddle SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
246ace01605SRiver Riddle
247ace01605SRiver Riddle // Try to collapse one of the current successors.
248ace01605SRiver Riddle LogicalResult collapsedTrue =
249ace01605SRiver Riddle collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
250ace01605SRiver Riddle LogicalResult collapsedFalse =
251ace01605SRiver Riddle collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
252ace01605SRiver Riddle if (failed(collapsedTrue) && failed(collapsedFalse))
253ace01605SRiver Riddle return failure();
254ace01605SRiver Riddle
255ace01605SRiver Riddle // Create a new branch with the collapsed successors.
256ace01605SRiver Riddle rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
257ace01605SRiver Riddle trueDest, trueDestOperands,
258ace01605SRiver Riddle falseDest, falseDestOperands);
259ace01605SRiver Riddle return success();
260ace01605SRiver Riddle }
261ace01605SRiver Riddle };
262ace01605SRiver Riddle
263ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
264ace01605SRiver Riddle /// -> br ^bb1(A, ..., N)
265ace01605SRiver Riddle ///
266ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
267ace01605SRiver Riddle /// -> %select = arith.select %cond, A, B
268ace01605SRiver Riddle /// br ^bb1(%select)
269ace01605SRiver Riddle ///
270ace01605SRiver Riddle struct SimplifyCondBranchIdenticalSuccessors
271ace01605SRiver Riddle : public OpRewritePattern<CondBranchOp> {
272ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern;
273ace01605SRiver Riddle
matchAndRewrite__anonbb8b539b0211::SimplifyCondBranchIdenticalSuccessors274ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr,
275ace01605SRiver Riddle PatternRewriter &rewriter) const override {
276ace01605SRiver Riddle // Check that the true and false destinations are the same and have the same
277ace01605SRiver Riddle // operands.
278ace01605SRiver Riddle Block *trueDest = condbr.getTrueDest();
279ace01605SRiver Riddle if (trueDest != condbr.getFalseDest())
280ace01605SRiver Riddle return failure();
281ace01605SRiver Riddle
282ace01605SRiver Riddle // If all of the operands match, no selects need to be generated.
283ace01605SRiver Riddle OperandRange trueOperands = condbr.getTrueOperands();
284ace01605SRiver Riddle OperandRange falseOperands = condbr.getFalseOperands();
285ace01605SRiver Riddle if (trueOperands == falseOperands) {
286ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
287ace01605SRiver Riddle return success();
288ace01605SRiver Riddle }
289ace01605SRiver Riddle
290ace01605SRiver Riddle // Otherwise, if the current block is the only predecessor insert selects
291ace01605SRiver Riddle // for any mismatched branch operands.
292ace01605SRiver Riddle if (trueDest->getUniquePredecessor() != condbr->getBlock())
293ace01605SRiver Riddle return failure();
294ace01605SRiver Riddle
295ace01605SRiver Riddle // Generate a select for any operands that differ between the two.
296ace01605SRiver Riddle SmallVector<Value, 8> mergedOperands;
297ace01605SRiver Riddle mergedOperands.reserve(trueOperands.size());
298ace01605SRiver Riddle Value condition = condbr.getCondition();
299ace01605SRiver Riddle for (auto it : llvm::zip(trueOperands, falseOperands)) {
300ace01605SRiver Riddle if (std::get<0>(it) == std::get<1>(it))
301ace01605SRiver Riddle mergedOperands.push_back(std::get<0>(it));
302ace01605SRiver Riddle else
303ace01605SRiver Riddle mergedOperands.push_back(rewriter.create<arith::SelectOp>(
304ace01605SRiver Riddle condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
305ace01605SRiver Riddle }
306ace01605SRiver Riddle
307ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
308ace01605SRiver Riddle return success();
309ace01605SRiver Riddle }
310ace01605SRiver Riddle };
311ace01605SRiver Riddle
312ace01605SRiver Riddle /// ...
313ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
314ace01605SRiver Riddle /// ...
315ace01605SRiver Riddle /// ^bb1: // has single predecessor
316ace01605SRiver Riddle /// ...
317ace01605SRiver Riddle /// cf.cond_br %cond, ^bb3(...), ^bb4(...)
318ace01605SRiver Riddle ///
319ace01605SRiver Riddle /// ->
320ace01605SRiver Riddle ///
321ace01605SRiver Riddle /// ...
322ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
323ace01605SRiver Riddle /// ...
324ace01605SRiver Riddle /// ^bb1: // has single predecessor
325ace01605SRiver Riddle /// ...
326ace01605SRiver Riddle /// br ^bb3(...)
327ace01605SRiver Riddle ///
328ace01605SRiver Riddle struct SimplifyCondBranchFromCondBranchOnSameCondition
329ace01605SRiver Riddle : public OpRewritePattern<CondBranchOp> {
330ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern;
331ace01605SRiver Riddle
matchAndRewrite__anonbb8b539b0211::SimplifyCondBranchFromCondBranchOnSameCondition332ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr,
333ace01605SRiver Riddle PatternRewriter &rewriter) const override {
334ace01605SRiver Riddle // Check that we have a single distinct predecessor.
335ace01605SRiver Riddle Block *currentBlock = condbr->getBlock();
336ace01605SRiver Riddle Block *predecessor = currentBlock->getSinglePredecessor();
337ace01605SRiver Riddle if (!predecessor)
338ace01605SRiver Riddle return failure();
339ace01605SRiver Riddle
340ace01605SRiver Riddle // Check that the predecessor terminates with a conditional branch to this
341ace01605SRiver Riddle // block and that it branches on the same condition.
342ace01605SRiver Riddle auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
343ace01605SRiver Riddle if (!predBranch || condbr.getCondition() != predBranch.getCondition())
344ace01605SRiver Riddle return failure();
345ace01605SRiver Riddle
346ace01605SRiver Riddle // Fold this branch to an unconditional branch.
347ace01605SRiver Riddle if (currentBlock == predBranch.getTrueDest())
348ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
349ace01605SRiver Riddle condbr.getTrueDestOperands());
350ace01605SRiver Riddle else
351ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
352ace01605SRiver Riddle condbr.getFalseDestOperands());
353ace01605SRiver Riddle return success();
354ace01605SRiver Riddle }
355ace01605SRiver Riddle };
356ace01605SRiver Riddle
357ace01605SRiver Riddle /// cf.cond_br %arg0, ^trueB, ^falseB
358ace01605SRiver Riddle ///
359ace01605SRiver Riddle /// ^trueB:
360ace01605SRiver Riddle /// "test.consumer1"(%arg0) : (i1) -> ()
361ace01605SRiver Riddle /// ...
362ace01605SRiver Riddle ///
363ace01605SRiver Riddle /// ^falseB:
364ace01605SRiver Riddle /// "test.consumer2"(%arg0) : (i1) -> ()
365ace01605SRiver Riddle /// ...
366ace01605SRiver Riddle ///
367ace01605SRiver Riddle /// ->
368ace01605SRiver Riddle ///
369ace01605SRiver Riddle /// cf.cond_br %arg0, ^trueB, ^falseB
370ace01605SRiver Riddle /// ^trueB:
371ace01605SRiver Riddle /// "test.consumer1"(%true) : (i1) -> ()
372ace01605SRiver Riddle /// ...
373ace01605SRiver Riddle ///
374ace01605SRiver Riddle /// ^falseB:
375ace01605SRiver Riddle /// "test.consumer2"(%false) : (i1) -> ()
376ace01605SRiver Riddle /// ...
377ace01605SRiver Riddle struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
378ace01605SRiver Riddle using OpRewritePattern<CondBranchOp>::OpRewritePattern;
379ace01605SRiver Riddle
matchAndRewrite__anonbb8b539b0211::CondBranchTruthPropagation380ace01605SRiver Riddle LogicalResult matchAndRewrite(CondBranchOp condbr,
381ace01605SRiver Riddle PatternRewriter &rewriter) const override {
382ace01605SRiver Riddle // Check that we have a single distinct predecessor.
383ace01605SRiver Riddle bool replaced = false;
384ace01605SRiver Riddle Type ty = rewriter.getI1Type();
385ace01605SRiver Riddle
386ace01605SRiver Riddle // These variables serve to prevent creating duplicate constants
387ace01605SRiver Riddle // and hold constant true or false values.
388ace01605SRiver Riddle Value constantTrue = nullptr;
389ace01605SRiver Riddle Value constantFalse = nullptr;
390ace01605SRiver Riddle
391ace01605SRiver Riddle // TODO These checks can be expanded to encompas any use with only
392ace01605SRiver Riddle // either the true of false edge as a predecessor. For now, we fall
393ace01605SRiver Riddle // back to checking the single predecessor is given by the true/fasle
394ace01605SRiver Riddle // destination, thereby ensuring that only that edge can reach the
395ace01605SRiver Riddle // op.
396ace01605SRiver Riddle if (condbr.getTrueDest()->getSinglePredecessor()) {
397ace01605SRiver Riddle for (OpOperand &use :
398ace01605SRiver Riddle llvm::make_early_inc_range(condbr.getCondition().getUses())) {
399ace01605SRiver Riddle if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
400ace01605SRiver Riddle replaced = true;
401ace01605SRiver Riddle
402ace01605SRiver Riddle if (!constantTrue)
403ace01605SRiver Riddle constantTrue = rewriter.create<arith::ConstantOp>(
404ace01605SRiver Riddle condbr.getLoc(), ty, rewriter.getBoolAttr(true));
405ace01605SRiver Riddle
406ace01605SRiver Riddle rewriter.updateRootInPlace(use.getOwner(),
407ace01605SRiver Riddle [&] { use.set(constantTrue); });
408ace01605SRiver Riddle }
409ace01605SRiver Riddle }
410ace01605SRiver Riddle }
411ace01605SRiver Riddle if (condbr.getFalseDest()->getSinglePredecessor()) {
412ace01605SRiver Riddle for (OpOperand &use :
413ace01605SRiver Riddle llvm::make_early_inc_range(condbr.getCondition().getUses())) {
414ace01605SRiver Riddle if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
415ace01605SRiver Riddle replaced = true;
416ace01605SRiver Riddle
417ace01605SRiver Riddle if (!constantFalse)
418ace01605SRiver Riddle constantFalse = rewriter.create<arith::ConstantOp>(
419ace01605SRiver Riddle condbr.getLoc(), ty, rewriter.getBoolAttr(false));
420ace01605SRiver Riddle
421ace01605SRiver Riddle rewriter.updateRootInPlace(use.getOwner(),
422ace01605SRiver Riddle [&] { use.set(constantFalse); });
423ace01605SRiver Riddle }
424ace01605SRiver Riddle }
425ace01605SRiver Riddle }
426ace01605SRiver Riddle return success(replaced);
427ace01605SRiver Riddle }
428ace01605SRiver Riddle };
429ace01605SRiver Riddle } // namespace
430ace01605SRiver Riddle
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)431ace01605SRiver Riddle void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
432ace01605SRiver Riddle MLIRContext *context) {
433ace01605SRiver Riddle results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
434ace01605SRiver Riddle SimplifyCondBranchIdenticalSuccessors,
435ace01605SRiver Riddle SimplifyCondBranchFromCondBranchOnSameCondition,
436ace01605SRiver Riddle CondBranchTruthPropagation>(context);
437ace01605SRiver Riddle }
438ace01605SRiver Riddle
getSuccessorOperands(unsigned index)4390c789db5SMarkus Böck SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
440ace01605SRiver Riddle assert(index < getNumSuccessors() && "invalid successor index");
4410c789db5SMarkus Böck return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
4420c789db5SMarkus Böck : getFalseDestOperandsMutable());
443ace01605SRiver Riddle }
444ace01605SRiver Riddle
getSuccessorForOperands(ArrayRef<Attribute> operands)445ace01605SRiver Riddle Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
446ace01605SRiver Riddle if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
447ace01605SRiver Riddle return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
448ace01605SRiver Riddle return nullptr;
449ace01605SRiver Riddle }
450ace01605SRiver Riddle
451ace01605SRiver Riddle //===----------------------------------------------------------------------===//
452ace01605SRiver Riddle // SwitchOp
453ace01605SRiver Riddle //===----------------------------------------------------------------------===//
454ace01605SRiver Riddle
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,DenseIntElementsAttr caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)455ace01605SRiver Riddle void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
456ace01605SRiver Riddle Block *defaultDestination, ValueRange defaultOperands,
457ace01605SRiver Riddle DenseIntElementsAttr caseValues,
458ace01605SRiver Riddle BlockRange caseDestinations,
459ace01605SRiver Riddle ArrayRef<ValueRange> caseOperands) {
460ace01605SRiver Riddle build(builder, result, value, defaultOperands, caseOperands, caseValues,
461ace01605SRiver Riddle defaultDestination, caseDestinations);
462ace01605SRiver Riddle }
463ace01605SRiver Riddle
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<APInt> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)464ace01605SRiver Riddle void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
465ace01605SRiver Riddle Block *defaultDestination, ValueRange defaultOperands,
466ace01605SRiver Riddle ArrayRef<APInt> caseValues, BlockRange caseDestinations,
467ace01605SRiver Riddle ArrayRef<ValueRange> caseOperands) {
468ace01605SRiver Riddle DenseIntElementsAttr caseValuesAttr;
469ace01605SRiver Riddle if (!caseValues.empty()) {
470ace01605SRiver Riddle ShapedType caseValueType = VectorType::get(
471ace01605SRiver Riddle static_cast<int64_t>(caseValues.size()), value.getType());
472ace01605SRiver Riddle caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
473ace01605SRiver Riddle }
474ace01605SRiver Riddle build(builder, result, value, defaultDestination, defaultOperands,
475ace01605SRiver Riddle caseValuesAttr, caseDestinations, caseOperands);
476ace01605SRiver Riddle }
477ace01605SRiver Riddle
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<int32_t> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)478b34fb277SAlexander Batashev void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
479b34fb277SAlexander Batashev Block *defaultDestination, ValueRange defaultOperands,
480b34fb277SAlexander Batashev ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
481b34fb277SAlexander Batashev ArrayRef<ValueRange> caseOperands) {
482b34fb277SAlexander Batashev DenseIntElementsAttr caseValuesAttr;
483b34fb277SAlexander Batashev if (!caseValues.empty()) {
484b34fb277SAlexander Batashev ShapedType caseValueType = VectorType::get(
485b34fb277SAlexander Batashev static_cast<int64_t>(caseValues.size()), value.getType());
486b34fb277SAlexander Batashev caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
487b34fb277SAlexander Batashev }
488b34fb277SAlexander Batashev build(builder, result, value, defaultDestination, defaultOperands,
489b34fb277SAlexander Batashev caseValuesAttr, caseDestinations, caseOperands);
490b34fb277SAlexander Batashev }
491b34fb277SAlexander Batashev
492ace01605SRiver Riddle /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
493ace01605SRiver Riddle /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
parseSwitchOpCases(OpAsmParser & parser,Type & flagType,Block * & defaultDestination,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & defaultOperands,SmallVectorImpl<Type> & defaultOperandTypes,DenseIntElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> & caseOperands,SmallVectorImpl<SmallVector<Type>> & caseOperandTypes)494ace01605SRiver Riddle static ParseResult parseSwitchOpCases(
495ace01605SRiver Riddle OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
496e13d23bcSMarkus Böck SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
497ace01605SRiver Riddle SmallVectorImpl<Type> &defaultOperandTypes,
498ace01605SRiver Riddle DenseIntElementsAttr &caseValues,
499ace01605SRiver Riddle SmallVectorImpl<Block *> &caseDestinations,
500e13d23bcSMarkus Böck SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
501ace01605SRiver Riddle SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
502ace01605SRiver Riddle if (parser.parseKeyword("default") || parser.parseColon() ||
503ace01605SRiver Riddle parser.parseSuccessor(defaultDestination))
504ace01605SRiver Riddle return failure();
505ace01605SRiver Riddle if (succeeded(parser.parseOptionalLParen())) {
5065dedf911SChris Lattner if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
5075dedf911SChris Lattner /*allowResultNumber=*/false) ||
508ace01605SRiver Riddle parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
509ace01605SRiver Riddle return failure();
510ace01605SRiver Riddle }
511ace01605SRiver Riddle
512ace01605SRiver Riddle SmallVector<APInt> values;
513ace01605SRiver Riddle unsigned bitWidth = flagType.getIntOrFloatBitWidth();
514ace01605SRiver Riddle while (succeeded(parser.parseOptionalComma())) {
515ace01605SRiver Riddle int64_t value = 0;
516ace01605SRiver Riddle if (failed(parser.parseInteger(value)))
517ace01605SRiver Riddle return failure();
518ace01605SRiver Riddle values.push_back(APInt(bitWidth, value));
519ace01605SRiver Riddle
520ace01605SRiver Riddle Block *destination;
521e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand> operands;
522ace01605SRiver Riddle SmallVector<Type> operandTypes;
523ace01605SRiver Riddle if (failed(parser.parseColon()) ||
524ace01605SRiver Riddle failed(parser.parseSuccessor(destination)))
525ace01605SRiver Riddle return failure();
526ace01605SRiver Riddle if (succeeded(parser.parseOptionalLParen())) {
5275dedf911SChris Lattner if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
5285dedf911SChris Lattner /*allowResultNumber=*/false)) ||
529ace01605SRiver Riddle failed(parser.parseColonTypeList(operandTypes)) ||
530ace01605SRiver Riddle failed(parser.parseRParen()))
531ace01605SRiver Riddle return failure();
532ace01605SRiver Riddle }
533ace01605SRiver Riddle caseDestinations.push_back(destination);
534ace01605SRiver Riddle caseOperands.emplace_back(operands);
535ace01605SRiver Riddle caseOperandTypes.emplace_back(operandTypes);
536ace01605SRiver Riddle }
537ace01605SRiver Riddle
538ace01605SRiver Riddle if (!values.empty()) {
539ace01605SRiver Riddle ShapedType caseValueType =
540ace01605SRiver Riddle VectorType::get(static_cast<int64_t>(values.size()), flagType);
541ace01605SRiver Riddle caseValues = DenseIntElementsAttr::get(caseValueType, values);
542ace01605SRiver Riddle }
543ace01605SRiver Riddle return success();
544ace01605SRiver Riddle }
545ace01605SRiver Riddle
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,Type flagType,Block * defaultDestination,OperandRange defaultOperands,TypeRange defaultOperandTypes,DenseIntElementsAttr caseValues,SuccessorRange caseDestinations,OperandRangeRange caseOperands,const TypeRangeRange & caseOperandTypes)546ace01605SRiver Riddle static void printSwitchOpCases(
547ace01605SRiver Riddle OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
548ace01605SRiver Riddle OperandRange defaultOperands, TypeRange defaultOperandTypes,
549ace01605SRiver Riddle DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
550ace01605SRiver Riddle OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
551ace01605SRiver Riddle p << " default: ";
552ace01605SRiver Riddle p.printSuccessorAndUseList(defaultDestination, defaultOperands);
553ace01605SRiver Riddle
554ace01605SRiver Riddle if (!caseValues)
555ace01605SRiver Riddle return;
556ace01605SRiver Riddle
557ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
558ace01605SRiver Riddle p << ',';
559ace01605SRiver Riddle p.printNewline();
560ace01605SRiver Riddle p << " ";
561ace01605SRiver Riddle p << it.value().getLimitedValue();
562ace01605SRiver Riddle p << ": ";
563ace01605SRiver Riddle p.printSuccessorAndUseList(caseDestinations[it.index()],
564ace01605SRiver Riddle caseOperands[it.index()]);
565ace01605SRiver Riddle }
566ace01605SRiver Riddle p.printNewline();
567ace01605SRiver Riddle }
568ace01605SRiver Riddle
verify()569ace01605SRiver Riddle LogicalResult SwitchOp::verify() {
570ace01605SRiver Riddle auto caseValues = getCaseValues();
571ace01605SRiver Riddle auto caseDestinations = getCaseDestinations();
572ace01605SRiver Riddle
573ace01605SRiver Riddle if (!caseValues && caseDestinations.empty())
574ace01605SRiver Riddle return success();
575ace01605SRiver Riddle
576ace01605SRiver Riddle Type flagType = getFlag().getType();
577ace01605SRiver Riddle Type caseValueType = caseValues->getType().getElementType();
578ace01605SRiver Riddle if (caseValueType != flagType)
579ace01605SRiver Riddle return emitOpError() << "'flag' type (" << flagType
580ace01605SRiver Riddle << ") should match case value type (" << caseValueType
581ace01605SRiver Riddle << ")";
582ace01605SRiver Riddle
583ace01605SRiver Riddle if (caseValues &&
584ace01605SRiver Riddle caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
585ace01605SRiver Riddle return emitOpError() << "number of case values (" << caseValues->size()
586ace01605SRiver Riddle << ") should match number of "
587ace01605SRiver Riddle "case destinations ("
588ace01605SRiver Riddle << caseDestinations.size() << ")";
589ace01605SRiver Riddle return success();
590ace01605SRiver Riddle }
591ace01605SRiver Riddle
getSuccessorOperands(unsigned index)5920c789db5SMarkus Böck SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
593ace01605SRiver Riddle assert(index < getNumSuccessors() && "invalid successor index");
5940c789db5SMarkus Böck return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
5950c789db5SMarkus Böck : getCaseOperandsMutable(index - 1));
596ace01605SRiver Riddle }
597ace01605SRiver Riddle
getSuccessorForOperands(ArrayRef<Attribute> operands)598ace01605SRiver Riddle Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
599ace01605SRiver Riddle Optional<DenseIntElementsAttr> caseValues = getCaseValues();
600ace01605SRiver Riddle
601ace01605SRiver Riddle if (!caseValues)
602ace01605SRiver Riddle return getDefaultDestination();
603ace01605SRiver Riddle
604ace01605SRiver Riddle SuccessorRange caseDests = getCaseDestinations();
605ace01605SRiver Riddle if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
606ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
607ace01605SRiver Riddle if (it.value() == value.getValue())
608ace01605SRiver Riddle return caseDests[it.index()];
609ace01605SRiver Riddle return getDefaultDestination();
610ace01605SRiver Riddle }
611ace01605SRiver Riddle return nullptr;
612ace01605SRiver Riddle }
613ace01605SRiver Riddle
614ace01605SRiver Riddle /// switch %flag : i32, [
615ace01605SRiver Riddle /// default: ^bb1
616ace01605SRiver Riddle /// ]
617ace01605SRiver Riddle /// -> br ^bb1
simplifySwitchWithOnlyDefault(SwitchOp op,PatternRewriter & rewriter)618ace01605SRiver Riddle static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
619ace01605SRiver Riddle PatternRewriter &rewriter) {
620ace01605SRiver Riddle if (!op.getCaseDestinations().empty())
621ace01605SRiver Riddle return failure();
622ace01605SRiver Riddle
623ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
624ace01605SRiver Riddle op.getDefaultOperands());
625ace01605SRiver Riddle return success();
626ace01605SRiver Riddle }
627ace01605SRiver Riddle
628ace01605SRiver Riddle /// switch %flag : i32, [
629ace01605SRiver Riddle /// default: ^bb1,
630ace01605SRiver Riddle /// 42: ^bb1,
631ace01605SRiver Riddle /// 43: ^bb2
632ace01605SRiver Riddle /// ]
633ace01605SRiver Riddle /// ->
634ace01605SRiver Riddle /// switch %flag : i32, [
635ace01605SRiver Riddle /// default: ^bb1,
636ace01605SRiver Riddle /// 43: ^bb2
637ace01605SRiver Riddle /// ]
638ace01605SRiver Riddle static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op,PatternRewriter & rewriter)639ace01605SRiver Riddle dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
640ace01605SRiver Riddle SmallVector<Block *> newCaseDestinations;
641ace01605SRiver Riddle SmallVector<ValueRange> newCaseOperands;
642ace01605SRiver Riddle SmallVector<APInt> newCaseValues;
643ace01605SRiver Riddle bool requiresChange = false;
644ace01605SRiver Riddle auto caseValues = op.getCaseValues();
645ace01605SRiver Riddle auto caseDests = op.getCaseDestinations();
646ace01605SRiver Riddle
647ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
648ace01605SRiver Riddle if (caseDests[it.index()] == op.getDefaultDestination() &&
649ace01605SRiver Riddle op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
650ace01605SRiver Riddle requiresChange = true;
651ace01605SRiver Riddle continue;
652ace01605SRiver Riddle }
653ace01605SRiver Riddle newCaseDestinations.push_back(caseDests[it.index()]);
654ace01605SRiver Riddle newCaseOperands.push_back(op.getCaseOperands(it.index()));
655ace01605SRiver Riddle newCaseValues.push_back(it.value());
656ace01605SRiver Riddle }
657ace01605SRiver Riddle
658ace01605SRiver Riddle if (!requiresChange)
659ace01605SRiver Riddle return failure();
660ace01605SRiver Riddle
661ace01605SRiver Riddle rewriter.replaceOpWithNewOp<SwitchOp>(
662ace01605SRiver Riddle op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
663ace01605SRiver Riddle newCaseValues, newCaseDestinations, newCaseOperands);
664ace01605SRiver Riddle return success();
665ace01605SRiver Riddle }
666ace01605SRiver Riddle
667ace01605SRiver Riddle /// Helper for folding a switch with a constant value.
668ace01605SRiver Riddle /// switch %c_42 : i32, [
669ace01605SRiver Riddle /// default: ^bb1 ,
670ace01605SRiver Riddle /// 42: ^bb2,
671ace01605SRiver Riddle /// 43: ^bb3
672ace01605SRiver Riddle /// ]
673ace01605SRiver Riddle /// -> br ^bb2
foldSwitch(SwitchOp op,PatternRewriter & rewriter,const APInt & caseValue)674ace01605SRiver Riddle static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
675ace01605SRiver Riddle const APInt &caseValue) {
676ace01605SRiver Riddle auto caseValues = op.getCaseValues();
677ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
678ace01605SRiver Riddle if (it.value() == caseValue) {
679ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(
680ace01605SRiver Riddle op, op.getCaseDestinations()[it.index()],
681ace01605SRiver Riddle op.getCaseOperands(it.index()));
682ace01605SRiver Riddle return;
683ace01605SRiver Riddle }
684ace01605SRiver Riddle }
685ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
686ace01605SRiver Riddle op.getDefaultOperands());
687ace01605SRiver Riddle }
688ace01605SRiver Riddle
689ace01605SRiver Riddle /// switch %c_42 : i32, [
690ace01605SRiver Riddle /// default: ^bb1,
691ace01605SRiver Riddle /// 42: ^bb2,
692ace01605SRiver Riddle /// 43: ^bb3
693ace01605SRiver Riddle /// ]
694ace01605SRiver Riddle /// -> br ^bb2
simplifyConstSwitchValue(SwitchOp op,PatternRewriter & rewriter)695ace01605SRiver Riddle static LogicalResult simplifyConstSwitchValue(SwitchOp op,
696ace01605SRiver Riddle PatternRewriter &rewriter) {
697ace01605SRiver Riddle APInt caseValue;
698ace01605SRiver Riddle if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
699ace01605SRiver Riddle return failure();
700ace01605SRiver Riddle
701ace01605SRiver Riddle foldSwitch(op, rewriter, caseValue);
702ace01605SRiver Riddle return success();
703ace01605SRiver Riddle }
704ace01605SRiver Riddle
705ace01605SRiver Riddle /// switch %c_42 : i32, [
706ace01605SRiver Riddle /// default: ^bb1,
707ace01605SRiver Riddle /// 42: ^bb2,
708ace01605SRiver Riddle /// ]
709ace01605SRiver Riddle /// ^bb2:
710ace01605SRiver Riddle /// br ^bb3
711ace01605SRiver Riddle /// ->
712ace01605SRiver Riddle /// switch %c_42 : i32, [
713ace01605SRiver Riddle /// default: ^bb1,
714ace01605SRiver Riddle /// 42: ^bb3,
715ace01605SRiver Riddle /// ]
simplifyPassThroughSwitch(SwitchOp op,PatternRewriter & rewriter)716ace01605SRiver Riddle static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
717ace01605SRiver Riddle PatternRewriter &rewriter) {
718ace01605SRiver Riddle SmallVector<Block *> newCaseDests;
719ace01605SRiver Riddle SmallVector<ValueRange> newCaseOperands;
720ace01605SRiver Riddle SmallVector<SmallVector<Value>> argStorage;
721ace01605SRiver Riddle auto caseValues = op.getCaseValues();
722f735b3a2SVitaly Buka argStorage.reserve(caseValues->size() + 1);
723ace01605SRiver Riddle auto caseDests = op.getCaseDestinations();
724ace01605SRiver Riddle bool requiresChange = false;
725ace01605SRiver Riddle for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
726ace01605SRiver Riddle Block *caseDest = caseDests[i];
727ace01605SRiver Riddle ValueRange caseOperands = op.getCaseOperands(i);
728ace01605SRiver Riddle argStorage.emplace_back();
729ace01605SRiver Riddle if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
730ace01605SRiver Riddle requiresChange = true;
731ace01605SRiver Riddle
732ace01605SRiver Riddle newCaseDests.push_back(caseDest);
733ace01605SRiver Riddle newCaseOperands.push_back(caseOperands);
734ace01605SRiver Riddle }
735ace01605SRiver Riddle
736ace01605SRiver Riddle Block *defaultDest = op.getDefaultDestination();
737ace01605SRiver Riddle ValueRange defaultOperands = op.getDefaultOperands();
738ace01605SRiver Riddle argStorage.emplace_back();
739ace01605SRiver Riddle
740ace01605SRiver Riddle if (succeeded(
741ace01605SRiver Riddle collapseBranch(defaultDest, defaultOperands, argStorage.back())))
742ace01605SRiver Riddle requiresChange = true;
743ace01605SRiver Riddle
744ace01605SRiver Riddle if (!requiresChange)
745ace01605SRiver Riddle return failure();
746ace01605SRiver Riddle
747ace01605SRiver Riddle rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
748*6d5fc1e3SKazu Hirata defaultOperands, *caseValues,
749ace01605SRiver Riddle newCaseDests, newCaseOperands);
750ace01605SRiver Riddle return success();
751ace01605SRiver Riddle }
752ace01605SRiver Riddle
753ace01605SRiver Riddle /// switch %flag : i32, [
754ace01605SRiver Riddle /// default: ^bb1,
755ace01605SRiver Riddle /// 42: ^bb2,
756ace01605SRiver Riddle /// ]
757ace01605SRiver Riddle /// ^bb2:
758ace01605SRiver Riddle /// switch %flag : i32, [
759ace01605SRiver Riddle /// default: ^bb3,
760ace01605SRiver Riddle /// 42: ^bb4
761ace01605SRiver Riddle /// ]
762ace01605SRiver Riddle /// ->
763ace01605SRiver Riddle /// switch %flag : i32, [
764ace01605SRiver Riddle /// default: ^bb1,
765ace01605SRiver Riddle /// 42: ^bb2,
766ace01605SRiver Riddle /// ]
767ace01605SRiver Riddle /// ^bb2:
768ace01605SRiver Riddle /// br ^bb4
769ace01605SRiver Riddle ///
770ace01605SRiver Riddle /// and
771ace01605SRiver Riddle ///
772ace01605SRiver Riddle /// switch %flag : i32, [
773ace01605SRiver Riddle /// default: ^bb1,
774ace01605SRiver Riddle /// 42: ^bb2,
775ace01605SRiver Riddle /// ]
776ace01605SRiver Riddle /// ^bb2:
777ace01605SRiver Riddle /// switch %flag : i32, [
778ace01605SRiver Riddle /// default: ^bb3,
779ace01605SRiver Riddle /// 43: ^bb4
780ace01605SRiver Riddle /// ]
781ace01605SRiver Riddle /// ->
782ace01605SRiver Riddle /// switch %flag : i32, [
783ace01605SRiver Riddle /// default: ^bb1,
784ace01605SRiver Riddle /// 42: ^bb2,
785ace01605SRiver Riddle /// ]
786ace01605SRiver Riddle /// ^bb2:
787ace01605SRiver Riddle /// br ^bb3
788ace01605SRiver Riddle static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)789ace01605SRiver Riddle simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
790ace01605SRiver Riddle PatternRewriter &rewriter) {
791ace01605SRiver Riddle // Check that we have a single distinct predecessor.
792ace01605SRiver Riddle Block *currentBlock = op->getBlock();
793ace01605SRiver Riddle Block *predecessor = currentBlock->getSinglePredecessor();
794ace01605SRiver Riddle if (!predecessor)
795ace01605SRiver Riddle return failure();
796ace01605SRiver Riddle
797ace01605SRiver Riddle // Check that the predecessor terminates with a switch branch to this block
798ace01605SRiver Riddle // and that it branches on the same condition and that this branch isn't the
799ace01605SRiver Riddle // default destination.
800ace01605SRiver Riddle auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
801ace01605SRiver Riddle if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
802ace01605SRiver Riddle predSwitch.getDefaultDestination() == currentBlock)
803ace01605SRiver Riddle return failure();
804ace01605SRiver Riddle
805ace01605SRiver Riddle // Fold this switch to an unconditional branch.
806ace01605SRiver Riddle SuccessorRange predDests = predSwitch.getCaseDestinations();
807ace01605SRiver Riddle auto it = llvm::find(predDests, currentBlock);
808ace01605SRiver Riddle if (it != predDests.end()) {
809ace01605SRiver Riddle Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
810ace01605SRiver Riddle foldSwitch(op, rewriter,
811ace01605SRiver Riddle predCaseValues->getValues<APInt>()[it - predDests.begin()]);
812ace01605SRiver Riddle } else {
813ace01605SRiver Riddle rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
814ace01605SRiver Riddle op.getDefaultOperands());
815ace01605SRiver Riddle }
816ace01605SRiver Riddle return success();
817ace01605SRiver Riddle }
818ace01605SRiver Riddle
819ace01605SRiver Riddle /// switch %flag : i32, [
820ace01605SRiver Riddle /// default: ^bb1,
821ace01605SRiver Riddle /// 42: ^bb2
822ace01605SRiver Riddle /// ]
823ace01605SRiver Riddle /// ^bb1:
824ace01605SRiver Riddle /// switch %flag : i32, [
825ace01605SRiver Riddle /// default: ^bb3,
826ace01605SRiver Riddle /// 42: ^bb4,
827ace01605SRiver Riddle /// 43: ^bb5
828ace01605SRiver Riddle /// ]
829ace01605SRiver Riddle /// ->
830ace01605SRiver Riddle /// switch %flag : i32, [
831ace01605SRiver Riddle /// default: ^bb1,
832ace01605SRiver Riddle /// 42: ^bb2,
833ace01605SRiver Riddle /// ]
834ace01605SRiver Riddle /// ^bb1:
835ace01605SRiver Riddle /// switch %flag : i32, [
836ace01605SRiver Riddle /// default: ^bb3,
837ace01605SRiver Riddle /// 43: ^bb5
838ace01605SRiver Riddle /// ]
839ace01605SRiver Riddle static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)840ace01605SRiver Riddle simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
841ace01605SRiver Riddle PatternRewriter &rewriter) {
842ace01605SRiver Riddle // Check that we have a single distinct predecessor.
843ace01605SRiver Riddle Block *currentBlock = op->getBlock();
844ace01605SRiver Riddle Block *predecessor = currentBlock->getSinglePredecessor();
845ace01605SRiver Riddle if (!predecessor)
846ace01605SRiver Riddle return failure();
847ace01605SRiver Riddle
848ace01605SRiver Riddle // Check that the predecessor terminates with a switch branch to this block
849ace01605SRiver Riddle // and that it branches on the same condition and that this branch is the
850ace01605SRiver Riddle // default destination.
851ace01605SRiver Riddle auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
852ace01605SRiver Riddle if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
853ace01605SRiver Riddle predSwitch.getDefaultDestination() != currentBlock)
854ace01605SRiver Riddle return failure();
855ace01605SRiver Riddle
856ace01605SRiver Riddle // Delete case values that are not possible here.
857ace01605SRiver Riddle DenseSet<APInt> caseValuesToRemove;
858ace01605SRiver Riddle auto predDests = predSwitch.getCaseDestinations();
859ace01605SRiver Riddle auto predCaseValues = predSwitch.getCaseValues();
860ace01605SRiver Riddle for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
861ace01605SRiver Riddle if (currentBlock != predDests[i])
862ace01605SRiver Riddle caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
863ace01605SRiver Riddle
864ace01605SRiver Riddle SmallVector<Block *> newCaseDestinations;
865ace01605SRiver Riddle SmallVector<ValueRange> newCaseOperands;
866ace01605SRiver Riddle SmallVector<APInt> newCaseValues;
867ace01605SRiver Riddle bool requiresChange = false;
868ace01605SRiver Riddle
869ace01605SRiver Riddle auto caseValues = op.getCaseValues();
870ace01605SRiver Riddle auto caseDests = op.getCaseDestinations();
871ace01605SRiver Riddle for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
872ace01605SRiver Riddle if (caseValuesToRemove.contains(it.value())) {
873ace01605SRiver Riddle requiresChange = true;
874ace01605SRiver Riddle continue;
875ace01605SRiver Riddle }
876ace01605SRiver Riddle newCaseDestinations.push_back(caseDests[it.index()]);
877ace01605SRiver Riddle newCaseOperands.push_back(op.getCaseOperands(it.index()));
878ace01605SRiver Riddle newCaseValues.push_back(it.value());
879ace01605SRiver Riddle }
880ace01605SRiver Riddle
881ace01605SRiver Riddle if (!requiresChange)
882ace01605SRiver Riddle return failure();
883ace01605SRiver Riddle
884ace01605SRiver Riddle rewriter.replaceOpWithNewOp<SwitchOp>(
885ace01605SRiver Riddle op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
886ace01605SRiver Riddle newCaseValues, newCaseDestinations, newCaseOperands);
887ace01605SRiver Riddle return success();
888ace01605SRiver Riddle }
889ace01605SRiver Riddle
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)890ace01605SRiver Riddle void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
891ace01605SRiver Riddle MLIRContext *context) {
892ace01605SRiver Riddle results.add(&simplifySwitchWithOnlyDefault)
893ace01605SRiver Riddle .add(&dropSwitchCasesThatMatchDefault)
894ace01605SRiver Riddle .add(&simplifyConstSwitchValue)
895ace01605SRiver Riddle .add(&simplifyPassThroughSwitch)
896ace01605SRiver Riddle .add(&simplifySwitchFromSwitchOnSameCondition)
897ace01605SRiver Riddle .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
898ace01605SRiver Riddle }
899ace01605SRiver Riddle
900ace01605SRiver Riddle //===----------------------------------------------------------------------===//
901ace01605SRiver Riddle // TableGen'd op method definitions
902ace01605SRiver Riddle //===----------------------------------------------------------------------===//
903ace01605SRiver Riddle
904ace01605SRiver Riddle #define GET_OP_CLASSES
905ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
906