1 //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/AffineExpr.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Support/MathExtras.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/APFloat.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/StringSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <numeric>
32
33 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34
35 using namespace mlir;
36 using namespace mlir::cf;
37
38 //===----------------------------------------------------------------------===//
39 // ControlFlowDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 namespace {
42 /// This class defines the interface for handling inlining with control flow
43 /// operations.
44 struct ControlFlowInlinerInterface : public DialectInlinerInterface {
45 using DialectInlinerInterface::DialectInlinerInterface;
46 ~ControlFlowInlinerInterface() override = default;
47
48 /// All control flow operations can be inlined.
isLegalToInline__anonbb8b539b0111::ControlFlowInlinerInterface49 bool isLegalToInline(Operation *call, Operation *callable,
50 bool wouldBeCloned) const final {
51 return true;
52 }
isLegalToInline__anonbb8b539b0111::ControlFlowInlinerInterface53 bool isLegalToInline(Operation *, Region *, bool,
54 BlockAndValueMapping &) const final {
55 return true;
56 }
57
58 /// ControlFlow terminator operations don't really need any special handing.
handleTerminator__anonbb8b539b0111::ControlFlowInlinerInterface59 void handleTerminator(Operation *op, Block *newDest) const final {}
60 };
61 } // namespace
62
63 //===----------------------------------------------------------------------===//
64 // ControlFlowDialect
65 //===----------------------------------------------------------------------===//
66
/// Dialect initialization hook.
///
/// Registers every operation listed by the TableGen-generated GET_OP_LIST
/// section of ControlFlowOps.cpp.inc. It then attaches
/// ControlFlowInlinerInterface so the inliner may process cf operations.
void ControlFlowDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  addInterfaces<ControlFlowInlinerInterface>();
}
74
75 //===----------------------------------------------------------------------===//
76 // AssertOp
77 //===----------------------------------------------------------------------===//
78
canonicalize(AssertOp op,PatternRewriter & rewriter)79 LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
80 // Erase assertion if argument is constant true.
81 if (matchPattern(op.getArg(), m_One())) {
82 rewriter.eraseOp(op);
83 return success();
84 }
85 return failure();
86 }
87
88 //===----------------------------------------------------------------------===//
89 // BranchOp
90 //===----------------------------------------------------------------------===//
91
92 /// Given a successor, try to collapse it to a new destination if it only
93 /// contains a passthrough unconditional branch. If the successor is
94 /// collapsable, `successor` and `successorOperands` are updated to reference
95 /// the new destination and values. `argStorage` is used as storage if operands
96 /// to the collapsed successor need to be remapped. It must outlive uses of
97 /// successorOperands.
collapseBranch(Block * & successor,ValueRange & successorOperands,SmallVectorImpl<Value> & argStorage)98 static LogicalResult collapseBranch(Block *&successor,
99 ValueRange &successorOperands,
100 SmallVectorImpl<Value> &argStorage) {
101 // Check that the successor only contains a unconditional branch.
102 if (std::next(successor->begin()) != successor->end())
103 return failure();
104 // Check that the terminator is an unconditional branch.
105 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
106 if (!successorBranch)
107 return failure();
108 // Check that the arguments are only used within the terminator.
109 for (BlockArgument arg : successor->getArguments()) {
110 for (Operation *user : arg.getUsers())
111 if (user != successorBranch)
112 return failure();
113 }
114 // Don't try to collapse branches to infinite loops.
115 Block *successorDest = successorBranch.getDest();
116 if (successorDest == successor)
117 return failure();
118
119 // Update the operands to the successor. If the branch parent has no
120 // arguments, we can use the branch operands directly.
121 OperandRange operands = successorBranch.getOperands();
122 if (successor->args_empty()) {
123 successor = successorDest;
124 successorOperands = operands;
125 return success();
126 }
127
128 // Otherwise, we need to remap any argument operands.
129 for (Value operand : operands) {
130 BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
131 if (argOperand && argOperand.getOwner() == successor)
132 argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
133 else
134 argStorage.push_back(operand);
135 }
136 successor = successorDest;
137 successorOperands = argStorage;
138 return success();
139 }
140
141 /// Simplify a branch to a block that has a single predecessor. This effectively
142 /// merges the two blocks.
143 static LogicalResult
simplifyBrToBlockWithSinglePred(BranchOp op,PatternRewriter & rewriter)144 simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
145 // Check that the successor block has a single predecessor.
146 Block *succ = op.getDest();
147 Block *opParent = op->getBlock();
148 if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
149 return failure();
150
151 // Merge the successor into the current block and erase the branch.
152 rewriter.mergeBlocks(succ, opParent, op.getOperands());
153 rewriter.eraseOp(op);
154 return success();
155 }
156
157 /// br ^bb1
158 /// ^bb1
159 /// br ^bbN(...)
160 ///
161 /// -> br ^bbN(...)
162 ///
simplifyPassThroughBr(BranchOp op,PatternRewriter & rewriter)163 static LogicalResult simplifyPassThroughBr(BranchOp op,
164 PatternRewriter &rewriter) {
165 Block *dest = op.getDest();
166 ValueRange destOperands = op.getOperands();
167 SmallVector<Value, 4> destOperandStorage;
168
169 // Try to collapse the successor if it points somewhere other than this
170 // block.
171 if (dest == op->getBlock() ||
172 failed(collapseBranch(dest, destOperands, destOperandStorage)))
173 return failure();
174
175 // Create a new branch with the collapsed successor.
176 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
177 return success();
178 }
179
canonicalize(BranchOp op,PatternRewriter & rewriter)180 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
181 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
182 succeeded(simplifyPassThroughBr(op, rewriter)));
183 }
184
setDest(Block * block)185 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
186
eraseOperand(unsigned index)187 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
188
getSuccessorOperands(unsigned index)189 SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
190 assert(index == 0 && "invalid successor index");
191 return SuccessorOperands(getDestOperandsMutable());
192 }
193
getSuccessorForOperands(ArrayRef<Attribute>)194 Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
195 return getDest();
196 }
197
198 //===----------------------------------------------------------------------===//
199 // CondBranchOp
200 //===----------------------------------------------------------------------===//
201
202 namespace {
203 /// cf.cond_br true, ^bb1, ^bb2
204 /// -> br ^bb1
205 /// cf.cond_br false, ^bb1, ^bb2
206 /// -> br ^bb2
207 ///
208 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
209 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
210
matchAndRewrite__anonbb8b539b0211::SimplifyConstCondBranchPred211 LogicalResult matchAndRewrite(CondBranchOp condbr,
212 PatternRewriter &rewriter) const override {
213 if (matchPattern(condbr.getCondition(), m_NonZero())) {
214 // True branch taken.
215 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
216 condbr.getTrueOperands());
217 return success();
218 }
219 if (matchPattern(condbr.getCondition(), m_Zero())) {
220 // False branch taken.
221 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
222 condbr.getFalseOperands());
223 return success();
224 }
225 return failure();
226 }
227 };
228
229 /// cf.cond_br %cond, ^bb1, ^bb2
230 /// ^bb1
231 /// br ^bbN(...)
232 /// ^bb2
233 /// br ^bbK(...)
234 ///
235 /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
236 ///
237 struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
238 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
239
matchAndRewrite__anonbb8b539b0211::SimplifyPassThroughCondBranch240 LogicalResult matchAndRewrite(CondBranchOp condbr,
241 PatternRewriter &rewriter) const override {
242 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
243 ValueRange trueDestOperands = condbr.getTrueOperands();
244 ValueRange falseDestOperands = condbr.getFalseOperands();
245 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
246
247 // Try to collapse one of the current successors.
248 LogicalResult collapsedTrue =
249 collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
250 LogicalResult collapsedFalse =
251 collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
252 if (failed(collapsedTrue) && failed(collapsedFalse))
253 return failure();
254
255 // Create a new branch with the collapsed successors.
256 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
257 trueDest, trueDestOperands,
258 falseDest, falseDestOperands);
259 return success();
260 }
261 };
262
263 /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
264 /// -> br ^bb1(A, ..., N)
265 ///
266 /// cf.cond_br %cond, ^bb1(A), ^bb1(B)
267 /// -> %select = arith.select %cond, A, B
268 /// br ^bb1(%select)
269 ///
270 struct SimplifyCondBranchIdenticalSuccessors
271 : public OpRewritePattern<CondBranchOp> {
272 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
273
matchAndRewrite__anonbb8b539b0211::SimplifyCondBranchIdenticalSuccessors274 LogicalResult matchAndRewrite(CondBranchOp condbr,
275 PatternRewriter &rewriter) const override {
276 // Check that the true and false destinations are the same and have the same
277 // operands.
278 Block *trueDest = condbr.getTrueDest();
279 if (trueDest != condbr.getFalseDest())
280 return failure();
281
282 // If all of the operands match, no selects need to be generated.
283 OperandRange trueOperands = condbr.getTrueOperands();
284 OperandRange falseOperands = condbr.getFalseOperands();
285 if (trueOperands == falseOperands) {
286 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
287 return success();
288 }
289
290 // Otherwise, if the current block is the only predecessor insert selects
291 // for any mismatched branch operands.
292 if (trueDest->getUniquePredecessor() != condbr->getBlock())
293 return failure();
294
295 // Generate a select for any operands that differ between the two.
296 SmallVector<Value, 8> mergedOperands;
297 mergedOperands.reserve(trueOperands.size());
298 Value condition = condbr.getCondition();
299 for (auto it : llvm::zip(trueOperands, falseOperands)) {
300 if (std::get<0>(it) == std::get<1>(it))
301 mergedOperands.push_back(std::get<0>(it));
302 else
303 mergedOperands.push_back(rewriter.create<arith::SelectOp>(
304 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
305 }
306
307 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
308 return success();
309 }
310 };
311
312 /// ...
313 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
314 /// ...
315 /// ^bb1: // has single predecessor
316 /// ...
317 /// cf.cond_br %cond, ^bb3(...), ^bb4(...)
318 ///
319 /// ->
320 ///
321 /// ...
322 /// cf.cond_br %cond, ^bb1(...), ^bb2(...)
323 /// ...
324 /// ^bb1: // has single predecessor
325 /// ...
326 /// br ^bb3(...)
327 ///
328 struct SimplifyCondBranchFromCondBranchOnSameCondition
329 : public OpRewritePattern<CondBranchOp> {
330 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
331
matchAndRewrite__anonbb8b539b0211::SimplifyCondBranchFromCondBranchOnSameCondition332 LogicalResult matchAndRewrite(CondBranchOp condbr,
333 PatternRewriter &rewriter) const override {
334 // Check that we have a single distinct predecessor.
335 Block *currentBlock = condbr->getBlock();
336 Block *predecessor = currentBlock->getSinglePredecessor();
337 if (!predecessor)
338 return failure();
339
340 // Check that the predecessor terminates with a conditional branch to this
341 // block and that it branches on the same condition.
342 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
343 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
344 return failure();
345
346 // Fold this branch to an unconditional branch.
347 if (currentBlock == predBranch.getTrueDest())
348 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
349 condbr.getTrueDestOperands());
350 else
351 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
352 condbr.getFalseDestOperands());
353 return success();
354 }
355 };
356
357 /// cf.cond_br %arg0, ^trueB, ^falseB
358 ///
359 /// ^trueB:
360 /// "test.consumer1"(%arg0) : (i1) -> ()
361 /// ...
362 ///
363 /// ^falseB:
364 /// "test.consumer2"(%arg0) : (i1) -> ()
365 /// ...
366 ///
367 /// ->
368 ///
369 /// cf.cond_br %arg0, ^trueB, ^falseB
370 /// ^trueB:
371 /// "test.consumer1"(%true) : (i1) -> ()
372 /// ...
373 ///
374 /// ^falseB:
375 /// "test.consumer2"(%false) : (i1) -> ()
376 /// ...
377 struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
378 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
379
matchAndRewrite__anonbb8b539b0211::CondBranchTruthPropagation380 LogicalResult matchAndRewrite(CondBranchOp condbr,
381 PatternRewriter &rewriter) const override {
382 // Check that we have a single distinct predecessor.
383 bool replaced = false;
384 Type ty = rewriter.getI1Type();
385
386 // These variables serve to prevent creating duplicate constants
387 // and hold constant true or false values.
388 Value constantTrue = nullptr;
389 Value constantFalse = nullptr;
390
391 // TODO These checks can be expanded to encompas any use with only
392 // either the true of false edge as a predecessor. For now, we fall
393 // back to checking the single predecessor is given by the true/fasle
394 // destination, thereby ensuring that only that edge can reach the
395 // op.
396 if (condbr.getTrueDest()->getSinglePredecessor()) {
397 for (OpOperand &use :
398 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
399 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
400 replaced = true;
401
402 if (!constantTrue)
403 constantTrue = rewriter.create<arith::ConstantOp>(
404 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
405
406 rewriter.updateRootInPlace(use.getOwner(),
407 [&] { use.set(constantTrue); });
408 }
409 }
410 }
411 if (condbr.getFalseDest()->getSinglePredecessor()) {
412 for (OpOperand &use :
413 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
414 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
415 replaced = true;
416
417 if (!constantFalse)
418 constantFalse = rewriter.create<arith::ConstantOp>(
419 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
420
421 rewriter.updateRootInPlace(use.getOwner(),
422 [&] { use.set(constantFalse); });
423 }
424 }
425 }
426 return success(replaced);
427 }
428 };
429 } // namespace
430
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)431 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
432 MLIRContext *context) {
433 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
434 SimplifyCondBranchIdenticalSuccessors,
435 SimplifyCondBranchFromCondBranchOnSameCondition,
436 CondBranchTruthPropagation>(context);
437 }
438
getSuccessorOperands(unsigned index)439 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
440 assert(index < getNumSuccessors() && "invalid successor index");
441 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
442 : getFalseDestOperandsMutable());
443 }
444
getSuccessorForOperands(ArrayRef<Attribute> operands)445 Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
446 if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
447 return condAttr.getValue().isOneValue() ? getTrueDest() : getFalseDest();
448 return nullptr;
449 }
450
451 //===----------------------------------------------------------------------===//
452 // SwitchOp
453 //===----------------------------------------------------------------------===//
454
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,DenseIntElementsAttr caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)455 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
456 Block *defaultDestination, ValueRange defaultOperands,
457 DenseIntElementsAttr caseValues,
458 BlockRange caseDestinations,
459 ArrayRef<ValueRange> caseOperands) {
460 build(builder, result, value, defaultOperands, caseOperands, caseValues,
461 defaultDestination, caseDestinations);
462 }
463
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<APInt> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)464 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
465 Block *defaultDestination, ValueRange defaultOperands,
466 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
467 ArrayRef<ValueRange> caseOperands) {
468 DenseIntElementsAttr caseValuesAttr;
469 if (!caseValues.empty()) {
470 ShapedType caseValueType = VectorType::get(
471 static_cast<int64_t>(caseValues.size()), value.getType());
472 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
473 }
474 build(builder, result, value, defaultDestination, defaultOperands,
475 caseValuesAttr, caseDestinations, caseOperands);
476 }
477
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<int32_t> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands)478 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
479 Block *defaultDestination, ValueRange defaultOperands,
480 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
481 ArrayRef<ValueRange> caseOperands) {
482 DenseIntElementsAttr caseValuesAttr;
483 if (!caseValues.empty()) {
484 ShapedType caseValueType = VectorType::get(
485 static_cast<int64_t>(caseValues.size()), value.getType());
486 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
487 }
488 build(builder, result, value, defaultDestination, defaultOperands,
489 caseValuesAttr, caseDestinations, caseOperands);
490 }
491
492 /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
493 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
parseSwitchOpCases(OpAsmParser & parser,Type & flagType,Block * & defaultDestination,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & defaultOperands,SmallVectorImpl<Type> & defaultOperandTypes,DenseIntElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> & caseOperands,SmallVectorImpl<SmallVector<Type>> & caseOperandTypes)494 static ParseResult parseSwitchOpCases(
495 OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
496 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
497 SmallVectorImpl<Type> &defaultOperandTypes,
498 DenseIntElementsAttr &caseValues,
499 SmallVectorImpl<Block *> &caseDestinations,
500 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
501 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
502 if (parser.parseKeyword("default") || parser.parseColon() ||
503 parser.parseSuccessor(defaultDestination))
504 return failure();
505 if (succeeded(parser.parseOptionalLParen())) {
506 if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
507 /*allowResultNumber=*/false) ||
508 parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
509 return failure();
510 }
511
512 SmallVector<APInt> values;
513 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
514 while (succeeded(parser.parseOptionalComma())) {
515 int64_t value = 0;
516 if (failed(parser.parseInteger(value)))
517 return failure();
518 values.push_back(APInt(bitWidth, value));
519
520 Block *destination;
521 SmallVector<OpAsmParser::UnresolvedOperand> operands;
522 SmallVector<Type> operandTypes;
523 if (failed(parser.parseColon()) ||
524 failed(parser.parseSuccessor(destination)))
525 return failure();
526 if (succeeded(parser.parseOptionalLParen())) {
527 if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
528 /*allowResultNumber=*/false)) ||
529 failed(parser.parseColonTypeList(operandTypes)) ||
530 failed(parser.parseRParen()))
531 return failure();
532 }
533 caseDestinations.push_back(destination);
534 caseOperands.emplace_back(operands);
535 caseOperandTypes.emplace_back(operandTypes);
536 }
537
538 if (!values.empty()) {
539 ShapedType caseValueType =
540 VectorType::get(static_cast<int64_t>(values.size()), flagType);
541 caseValues = DenseIntElementsAttr::get(caseValueType, values);
542 }
543 return success();
544 }
545
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,Type flagType,Block * defaultDestination,OperandRange defaultOperands,TypeRange defaultOperandTypes,DenseIntElementsAttr caseValues,SuccessorRange caseDestinations,OperandRangeRange caseOperands,const TypeRangeRange & caseOperandTypes)546 static void printSwitchOpCases(
547 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
548 OperandRange defaultOperands, TypeRange defaultOperandTypes,
549 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
550 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
551 p << " default: ";
552 p.printSuccessorAndUseList(defaultDestination, defaultOperands);
553
554 if (!caseValues)
555 return;
556
557 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
558 p << ',';
559 p.printNewline();
560 p << " ";
561 p << it.value().getLimitedValue();
562 p << ": ";
563 p.printSuccessorAndUseList(caseDestinations[it.index()],
564 caseOperands[it.index()]);
565 }
566 p.printNewline();
567 }
568
verify()569 LogicalResult SwitchOp::verify() {
570 auto caseValues = getCaseValues();
571 auto caseDestinations = getCaseDestinations();
572
573 if (!caseValues && caseDestinations.empty())
574 return success();
575
576 Type flagType = getFlag().getType();
577 Type caseValueType = caseValues->getType().getElementType();
578 if (caseValueType != flagType)
579 return emitOpError() << "'flag' type (" << flagType
580 << ") should match case value type (" << caseValueType
581 << ")";
582
583 if (caseValues &&
584 caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
585 return emitOpError() << "number of case values (" << caseValues->size()
586 << ") should match number of "
587 "case destinations ("
588 << caseDestinations.size() << ")";
589 return success();
590 }
591
getSuccessorOperands(unsigned index)592 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
593 assert(index < getNumSuccessors() && "invalid successor index");
594 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
595 : getCaseOperandsMutable(index - 1));
596 }
597
getSuccessorForOperands(ArrayRef<Attribute> operands)598 Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
599 Optional<DenseIntElementsAttr> caseValues = getCaseValues();
600
601 if (!caseValues)
602 return getDefaultDestination();
603
604 SuccessorRange caseDests = getCaseDestinations();
605 if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
606 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
607 if (it.value() == value.getValue())
608 return caseDests[it.index()];
609 return getDefaultDestination();
610 }
611 return nullptr;
612 }
613
614 /// switch %flag : i32, [
615 /// default: ^bb1
616 /// ]
617 /// -> br ^bb1
simplifySwitchWithOnlyDefault(SwitchOp op,PatternRewriter & rewriter)618 static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
619 PatternRewriter &rewriter) {
620 if (!op.getCaseDestinations().empty())
621 return failure();
622
623 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
624 op.getDefaultOperands());
625 return success();
626 }
627
628 /// switch %flag : i32, [
629 /// default: ^bb1,
630 /// 42: ^bb1,
631 /// 43: ^bb2
632 /// ]
633 /// ->
634 /// switch %flag : i32, [
635 /// default: ^bb1,
636 /// 43: ^bb2
637 /// ]
638 static LogicalResult
dropSwitchCasesThatMatchDefault(SwitchOp op,PatternRewriter & rewriter)639 dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
640 SmallVector<Block *> newCaseDestinations;
641 SmallVector<ValueRange> newCaseOperands;
642 SmallVector<APInt> newCaseValues;
643 bool requiresChange = false;
644 auto caseValues = op.getCaseValues();
645 auto caseDests = op.getCaseDestinations();
646
647 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
648 if (caseDests[it.index()] == op.getDefaultDestination() &&
649 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
650 requiresChange = true;
651 continue;
652 }
653 newCaseDestinations.push_back(caseDests[it.index()]);
654 newCaseOperands.push_back(op.getCaseOperands(it.index()));
655 newCaseValues.push_back(it.value());
656 }
657
658 if (!requiresChange)
659 return failure();
660
661 rewriter.replaceOpWithNewOp<SwitchOp>(
662 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
663 newCaseValues, newCaseDestinations, newCaseOperands);
664 return success();
665 }
666
667 /// Helper for folding a switch with a constant value.
668 /// switch %c_42 : i32, [
669 /// default: ^bb1 ,
670 /// 42: ^bb2,
671 /// 43: ^bb3
672 /// ]
673 /// -> br ^bb2
foldSwitch(SwitchOp op,PatternRewriter & rewriter,const APInt & caseValue)674 static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
675 const APInt &caseValue) {
676 auto caseValues = op.getCaseValues();
677 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
678 if (it.value() == caseValue) {
679 rewriter.replaceOpWithNewOp<BranchOp>(
680 op, op.getCaseDestinations()[it.index()],
681 op.getCaseOperands(it.index()));
682 return;
683 }
684 }
685 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
686 op.getDefaultOperands());
687 }
688
689 /// switch %c_42 : i32, [
690 /// default: ^bb1,
691 /// 42: ^bb2,
692 /// 43: ^bb3
693 /// ]
694 /// -> br ^bb2
simplifyConstSwitchValue(SwitchOp op,PatternRewriter & rewriter)695 static LogicalResult simplifyConstSwitchValue(SwitchOp op,
696 PatternRewriter &rewriter) {
697 APInt caseValue;
698 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
699 return failure();
700
701 foldSwitch(op, rewriter, caseValue);
702 return success();
703 }
704
705 /// switch %c_42 : i32, [
706 /// default: ^bb1,
707 /// 42: ^bb2,
708 /// ]
709 /// ^bb2:
710 /// br ^bb3
711 /// ->
712 /// switch %c_42 : i32, [
713 /// default: ^bb1,
714 /// 42: ^bb3,
715 /// ]
simplifyPassThroughSwitch(SwitchOp op,PatternRewriter & rewriter)716 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
717 PatternRewriter &rewriter) {
718 SmallVector<Block *> newCaseDests;
719 SmallVector<ValueRange> newCaseOperands;
720 SmallVector<SmallVector<Value>> argStorage;
721 auto caseValues = op.getCaseValues();
722 argStorage.reserve(caseValues->size() + 1);
723 auto caseDests = op.getCaseDestinations();
724 bool requiresChange = false;
725 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
726 Block *caseDest = caseDests[i];
727 ValueRange caseOperands = op.getCaseOperands(i);
728 argStorage.emplace_back();
729 if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
730 requiresChange = true;
731
732 newCaseDests.push_back(caseDest);
733 newCaseOperands.push_back(caseOperands);
734 }
735
736 Block *defaultDest = op.getDefaultDestination();
737 ValueRange defaultOperands = op.getDefaultOperands();
738 argStorage.emplace_back();
739
740 if (succeeded(
741 collapseBranch(defaultDest, defaultOperands, argStorage.back())))
742 requiresChange = true;
743
744 if (!requiresChange)
745 return failure();
746
747 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
748 defaultOperands, *caseValues,
749 newCaseDests, newCaseOperands);
750 return success();
751 }
752
753 /// switch %flag : i32, [
754 /// default: ^bb1,
755 /// 42: ^bb2,
756 /// ]
757 /// ^bb2:
758 /// switch %flag : i32, [
759 /// default: ^bb3,
760 /// 42: ^bb4
761 /// ]
762 /// ->
763 /// switch %flag : i32, [
764 /// default: ^bb1,
765 /// 42: ^bb2,
766 /// ]
767 /// ^bb2:
768 /// br ^bb4
769 ///
770 /// and
771 ///
772 /// switch %flag : i32, [
773 /// default: ^bb1,
774 /// 42: ^bb2,
775 /// ]
776 /// ^bb2:
777 /// switch %flag : i32, [
778 /// default: ^bb3,
779 /// 43: ^bb4
780 /// ]
781 /// ->
782 /// switch %flag : i32, [
783 /// default: ^bb1,
784 /// 42: ^bb2,
785 /// ]
786 /// ^bb2:
787 /// br ^bb3
788 static LogicalResult
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)789 simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
790 PatternRewriter &rewriter) {
791 // Check that we have a single distinct predecessor.
792 Block *currentBlock = op->getBlock();
793 Block *predecessor = currentBlock->getSinglePredecessor();
794 if (!predecessor)
795 return failure();
796
797 // Check that the predecessor terminates with a switch branch to this block
798 // and that it branches on the same condition and that this branch isn't the
799 // default destination.
800 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
801 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
802 predSwitch.getDefaultDestination() == currentBlock)
803 return failure();
804
805 // Fold this switch to an unconditional branch.
806 SuccessorRange predDests = predSwitch.getCaseDestinations();
807 auto it = llvm::find(predDests, currentBlock);
808 if (it != predDests.end()) {
809 Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
810 foldSwitch(op, rewriter,
811 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
812 } else {
813 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
814 op.getDefaultOperands());
815 }
816 return success();
817 }
818
819 /// switch %flag : i32, [
820 /// default: ^bb1,
821 /// 42: ^bb2
822 /// ]
823 /// ^bb1:
824 /// switch %flag : i32, [
825 /// default: ^bb3,
826 /// 42: ^bb4,
827 /// 43: ^bb5
828 /// ]
829 /// ->
830 /// switch %flag : i32, [
831 /// default: ^bb1,
832 /// 42: ^bb2,
833 /// ]
834 /// ^bb1:
835 /// switch %flag : i32, [
836 /// default: ^bb3,
837 /// 43: ^bb5
838 /// ]
839 static LogicalResult
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,PatternRewriter & rewriter)840 simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
841 PatternRewriter &rewriter) {
842 // Check that we have a single distinct predecessor.
843 Block *currentBlock = op->getBlock();
844 Block *predecessor = currentBlock->getSinglePredecessor();
845 if (!predecessor)
846 return failure();
847
848 // Check that the predecessor terminates with a switch branch to this block
849 // and that it branches on the same condition and that this branch is the
850 // default destination.
851 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
852 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
853 predSwitch.getDefaultDestination() != currentBlock)
854 return failure();
855
856 // Delete case values that are not possible here.
857 DenseSet<APInt> caseValuesToRemove;
858 auto predDests = predSwitch.getCaseDestinations();
859 auto predCaseValues = predSwitch.getCaseValues();
860 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
861 if (currentBlock != predDests[i])
862 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
863
864 SmallVector<Block *> newCaseDestinations;
865 SmallVector<ValueRange> newCaseOperands;
866 SmallVector<APInt> newCaseValues;
867 bool requiresChange = false;
868
869 auto caseValues = op.getCaseValues();
870 auto caseDests = op.getCaseDestinations();
871 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
872 if (caseValuesToRemove.contains(it.value())) {
873 requiresChange = true;
874 continue;
875 }
876 newCaseDestinations.push_back(caseDests[it.index()]);
877 newCaseOperands.push_back(op.getCaseOperands(it.index()));
878 newCaseValues.push_back(it.value());
879 }
880
881 if (!requiresChange)
882 return failure();
883
884 rewriter.replaceOpWithNewOp<SwitchOp>(
885 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
886 newCaseValues, newCaseDestinations, newCaseOperands);
887 return success();
888 }
889
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)890 void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
891 MLIRContext *context) {
892 results.add(&simplifySwitchWithOnlyDefault)
893 .add(&dropSwitchCasesThatMatchDefault)
894 .add(&simplifyConstSwitchValue)
895 .add(&simplifyPassThroughSwitch)
896 .add(&simplifySwitchFromSwitchOnSameCondition)
897 .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
898 }
899
900 //===----------------------------------------------------------------------===//
901 // TableGen'd op method definitions
902 //===----------------------------------------------------------------------===//
903
904 #define GET_OP_CLASSES
905 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
906