1 //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert scf.for, scf.if, scf.parallel,
// scf.while and scf.execute_region ops into ControlFlow (cf) dialect ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "mlir/Transforms/Passes.h"
26
27 using namespace mlir;
28 using namespace mlir::scf;
29
30 namespace {
31
32 struct SCFToControlFlowPass
33 : public SCFToControlFlowBase<SCFToControlFlowPass> {
34 void runOnOperation() override;
35 };
36
37 // Create a CFG subgraph for the loop around its body blocks (if the body
38 // contained other loops, they have been already lowered to a flow of blocks).
39 // Maintain the invariants that a CFG subgraph created for any loop has a single
40 // entry and a single exit, and that the entry/exit blocks are respectively
41 // first/last blocks in the parent region. The original loop operation is
42 // replaced by the initialization operations that set up the initial value of
// the loop induction variable (%iv) and compute the loop bounds that are loop-
44 // invariant for affine loops. The operations following the original scf.for
45 // are split out into a separate continuation (exit) block. A condition block is
46 // created before the continuation block. It checks the exit condition of the
47 // loop and branches either to the continuation block, or to the first block of
48 // the body. The condition block takes as arguments the values of the induction
49 // variable followed by loop-carried values. Since it dominates both the body
50 // blocks and the continuation block, loop-carried values are visible in all of
51 // those blocks. Induction variable modification is appended to the last block
52 // of the body (which is the exit block from the body subgraph thanks to the
53 // invariant we maintain) along with a branch that loops back to the condition
54 // block. Loop-carried values are the loop terminator operands, which are
55 // forwarded to the branch.
56 //
57 // +---------------------------------+
58 // | <code before the ForOp> |
59 // | <definitions of %init...> |
60 // | <compute initial %iv value> |
61 // | cf.br cond(%iv, %init...) |
62 // +---------------------------------+
63 // |
64 // -------| |
65 // | v v
66 // | +--------------------------------+
67 // | | cond(%iv, %init...): |
68 // | | <compare %iv to upper bound> |
69 // | | cf.cond_br %r, body, end |
70 // | +--------------------------------+
71 // | | |
72 // | | -------------|
73 // | v |
74 // | +--------------------------------+ |
75 // | | body-first: | |
76 // | | <%init visible by dominance> | |
77 // | | <body contents> | |
78 // | +--------------------------------+ |
79 // | | |
80 // | ... |
81 // | | |
82 // | +--------------------------------+ |
83 // | | body-last: | |
84 // | | <body contents> | |
85 // | | <operands of yield = %yields>| |
86 // | | %new_iv =<add step to %iv> | |
87 // | | cf.br cond(%new_iv, %yields) | |
88 // | +--------------------------------+ |
89 // | | |
90 // |----------- |--------------------
91 // v
92 // +--------------------------------+
93 // | end: |
94 // | <code after the ForOp> |
95 // | <%init visible by dominance> |
96 // +--------------------------------+
97 //
98 struct ForLowering : public OpRewritePattern<ForOp> {
99 using OpRewritePattern<ForOp>::OpRewritePattern;
100
101 LogicalResult matchAndRewrite(ForOp forOp,
102 PatternRewriter &rewriter) const override;
103 };
104
105 // Create a CFG subgraph for the scf.if operation (including its "then" and
106 // optional "else" operation blocks). We maintain the invariants that the
107 // subgraph has a single entry and a single exit point, and that the entry/exit
108 // blocks are respectively the first/last block of the enclosing region. The
109 // operations following the scf.if are split into a continuation (subgraph
110 // exit) block. The condition is lowered to a chain of blocks that implement the
111 // short-circuit scheme. The "scf.if" operation is replaced with a conditional
112 // branch to either the first block of the "then" region, or to the first block
// of the "else" region. There, "scf.yield" becomes an unconditional branch
114 // to the post-dominating block. When the "scf.if" does not return values, the
115 // post-dominating block is the same as the continuation block. When it returns
116 // values, the post-dominating block is a new block with arguments that
117 // correspond to the values returned by the "scf.if" that unconditionally
118 // branches to the continuation block. This allows block arguments to dominate
119 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
120 // new block allows us to avoid modifying the argument list of an existing
121 // block, which is illegal in a conversion pattern). When the "else" region is
122 // empty, which is only allowed for "scf.if"s that don't return values, the
123 // condition branches directly to the continuation block.
124 //
125 // CFG for a scf.if with else and without results.
126 //
127 // +--------------------------------+
128 // | <code before the IfOp> |
129 // | cf.cond_br %cond, %then, %else |
130 // +--------------------------------+
131 // | |
132 // | --------------|
133 // v |
134 // +--------------------------------+ |
135 // | then: | |
136 // | <then contents> | |
137 // | cf.br continue | |
138 // +--------------------------------+ |
139 // | |
140 // |---------- |-------------
141 // | V
142 // | +--------------------------------+
143 // | | else: |
144 // | | <else contents> |
145 // | | cf.br continue |
146 // | +--------------------------------+
147 // | |
148 // ------| |
149 // v v
150 // +--------------------------------+
151 // | continue: |
152 // | <code after the IfOp> |
153 // +--------------------------------+
154 //
155 // CFG for a scf.if with results.
156 //
157 // +--------------------------------+
158 // | <code before the IfOp> |
159 // | cf.cond_br %cond, %then, %else |
160 // +--------------------------------+
161 // | |
162 // | --------------|
163 // v |
164 // +--------------------------------+ |
165 // | then: | |
166 // | <then contents> | |
167 // | cf.br dom(%args...) | |
168 // +--------------------------------+ |
169 // | |
170 // |---------- |-------------
171 // | V
172 // | +--------------------------------+
173 // | | else: |
174 // | | <else contents> |
175 // | | cf.br dom(%args...) |
176 // | +--------------------------------+
177 // | |
178 // ------| |
179 // v v
180 // +--------------------------------+
181 // | dom(%args...): |
182 // | cf.br continue |
183 // +--------------------------------+
184 // |
185 // v
186 // +--------------------------------+
187 // | continue: |
188 // | <code after the IfOp> |
189 // +--------------------------------+
190 //
191 struct IfLowering : public OpRewritePattern<IfOp> {
192 using OpRewritePattern<IfOp>::OpRewritePattern;
193
194 LogicalResult matchAndRewrite(IfOp ifOp,
195 PatternRewriter &rewriter) const override;
196 };
197
198 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
199 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
200
201 LogicalResult matchAndRewrite(ExecuteRegionOp op,
202 PatternRewriter &rewriter) const override;
203 };
204
205 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
206 using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
207
208 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
209 PatternRewriter &rewriter) const override;
210 };
211
212 /// Create a CFG subgraph for this loop construct. The regions of the loop need
213 /// not be a single block anymore (for example, if other SCF constructs that
214 /// they contain have been already converted to CFG), but need to be single-exit
215 /// from the last block of each region. The operations following the original
216 /// WhileOp are split into a new continuation block. Both regions of the WhileOp
217 /// are inlined, and their terminators are rewritten to organize the control
218 /// flow implementing the loop as follows.
219 ///
220 /// +---------------------------------+
221 /// | <code before the WhileOp> |
222 /// | cf.br ^before(%operands...) |
223 /// +---------------------------------+
224 /// |
225 /// -------| |
226 /// | v v
227 /// | +--------------------------------+
228 /// | | ^before(%bargs...): |
229 /// | | %vals... = <some payload> |
230 /// | +--------------------------------+
231 /// | |
232 /// | ...
233 /// | |
234 /// | +--------------------------------+
235 /// | | ^before-last:
236 /// | | %cond = <compute condition> |
237 /// | | cf.cond_br %cond, |
238 /// | | ^after(%vals...), ^cont |
239 /// | +--------------------------------+
240 /// | | |
241 /// | | -------------|
242 /// | v |
243 /// | +--------------------------------+ |
244 /// | | ^after(%aargs...): | |
245 /// | | <body contents> | |
246 /// | +--------------------------------+ |
247 /// | | |
248 /// | ... |
249 /// | | |
250 /// | +--------------------------------+ |
251 /// | | ^after-last: | |
252 /// | | %yields... = <some payload> | |
253 /// | | cf.br ^before(%yields...) | |
254 /// | +--------------------------------+ |
255 /// | | |
256 /// |----------- |--------------------
257 /// v
258 /// +--------------------------------+
259 /// | ^cont: |
260 /// | <code after the WhileOp> |
261 /// | <%vals from 'before' region |
262 /// | visible by dominance> |
263 /// +--------------------------------+
264 ///
265 /// Values are communicated between ex-regions (the groups of blocks that used
266 /// to form a region before inlining) through block arguments of their
267 /// entry blocks, which are visible in all other dominated blocks. Similarly,
268 /// the results of the WhileOp are defined in the 'before' region, which is
/// required to have a single exiting block, and are therefore accessible in
270 /// the continuation block due to dominance.
271 struct WhileLowering : public OpRewritePattern<WhileOp> {
272 using OpRewritePattern<WhileOp>::OpRewritePattern;
273
274 LogicalResult matchAndRewrite(WhileOp whileOp,
275 PatternRewriter &rewriter) const override;
276 };
277
278 /// Optimized version of the above for the case of the "after" region merely
279 /// forwarding its arguments back to the "before" region (i.e., a "do-while"
280 /// loop). This avoid inlining the "after" region completely and branches back
281 /// to the "before" entry instead.
282 struct DoWhileLowering : public OpRewritePattern<WhileOp> {
283 using OpRewritePattern<WhileOp>::OpRewritePattern;
284
285 LogicalResult matchAndRewrite(WhileOp whileOp,
286 PatternRewriter &rewriter) const override;
287 };
288 } // namespace
289
matchAndRewrite(ForOp forOp,PatternRewriter & rewriter) const290 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
291 PatternRewriter &rewriter) const {
292 Location loc = forOp.getLoc();
293
294 // Start by splitting the block containing the 'scf.for' into two parts.
295 // The part before will get the init code, the part after will be the end
296 // point.
297 auto *initBlock = rewriter.getInsertionBlock();
298 auto initPosition = rewriter.getInsertionPoint();
299 auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
300
301 // Use the first block of the loop body as the condition block since it is the
302 // block that has the induction variable and loop-carried values as arguments.
303 // Split out all operations from the first block into a new block. Move all
304 // body blocks from the loop body region to the region containing the loop.
305 auto *conditionBlock = &forOp.getRegion().front();
306 auto *firstBodyBlock =
307 rewriter.splitBlock(conditionBlock, conditionBlock->begin());
308 auto *lastBodyBlock = &forOp.getRegion().back();
309 rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
310 auto iv = conditionBlock->getArgument(0);
311
312 // Append the induction variable stepping logic to the last body block and
313 // branch back to the condition block. Loop-carried values are taken from
314 // operands of the loop terminator.
315 Operation *terminator = lastBodyBlock->getTerminator();
316 rewriter.setInsertionPointToEnd(lastBodyBlock);
317 auto step = forOp.getStep();
318 auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
319 if (!stepped)
320 return failure();
321
322 SmallVector<Value, 8> loopCarried;
323 loopCarried.push_back(stepped);
324 loopCarried.append(terminator->operand_begin(), terminator->operand_end());
325 rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
326 rewriter.eraseOp(terminator);
327
328 // Compute loop bounds before branching to the condition.
329 rewriter.setInsertionPointToEnd(initBlock);
330 Value lowerBound = forOp.getLowerBound();
331 Value upperBound = forOp.getUpperBound();
332 if (!lowerBound || !upperBound)
333 return failure();
334
335 // The initial values of loop-carried values is obtained from the operands
336 // of the loop operation.
337 SmallVector<Value, 8> destOperands;
338 destOperands.push_back(lowerBound);
339 auto iterOperands = forOp.getIterOperands();
340 destOperands.append(iterOperands.begin(), iterOperands.end());
341 rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
342
343 // With the body block done, we can fill in the condition block.
344 rewriter.setInsertionPointToEnd(conditionBlock);
345 auto comparison = rewriter.create<arith::CmpIOp>(
346 loc, arith::CmpIPredicate::slt, iv, upperBound);
347
348 rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
349 ArrayRef<Value>(), endBlock,
350 ArrayRef<Value>());
351 // The result of the loop operation is the values of the condition block
352 // arguments except the induction variable on the last iteration.
353 rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
354 return success();
355 }
356
matchAndRewrite(IfOp ifOp,PatternRewriter & rewriter) const357 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
358 PatternRewriter &rewriter) const {
359 auto loc = ifOp.getLoc();
360
361 // Start by splitting the block containing the 'scf.if' into two parts.
362 // The part before will contain the condition, the part after will be the
363 // continuation point.
364 auto *condBlock = rewriter.getInsertionBlock();
365 auto opPosition = rewriter.getInsertionPoint();
366 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
367 Block *continueBlock;
368 if (ifOp.getNumResults() == 0) {
369 continueBlock = remainingOpsBlock;
370 } else {
371 continueBlock =
372 rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
373 SmallVector<Location>(ifOp.getNumResults(), loc));
374 rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
375 }
376
377 // Move blocks from the "then" region to the region containing 'scf.if',
378 // place it before the continuation block, and branch to it.
379 auto &thenRegion = ifOp.getThenRegion();
380 auto *thenBlock = &thenRegion.front();
381 Operation *thenTerminator = thenRegion.back().getTerminator();
382 ValueRange thenTerminatorOperands = thenTerminator->getOperands();
383 rewriter.setInsertionPointToEnd(&thenRegion.back());
384 rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
385 rewriter.eraseOp(thenTerminator);
386 rewriter.inlineRegionBefore(thenRegion, continueBlock);
387
388 // Move blocks from the "else" region (if present) to the region containing
389 // 'scf.if', place it before the continuation block and branch to it. It
390 // will be placed after the "then" regions.
391 auto *elseBlock = continueBlock;
392 auto &elseRegion = ifOp.getElseRegion();
393 if (!elseRegion.empty()) {
394 elseBlock = &elseRegion.front();
395 Operation *elseTerminator = elseRegion.back().getTerminator();
396 ValueRange elseTerminatorOperands = elseTerminator->getOperands();
397 rewriter.setInsertionPointToEnd(&elseRegion.back());
398 rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
399 rewriter.eraseOp(elseTerminator);
400 rewriter.inlineRegionBefore(elseRegion, continueBlock);
401 }
402
403 rewriter.setInsertionPointToEnd(condBlock);
404 rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
405 /*trueArgs=*/ArrayRef<Value>(), elseBlock,
406 /*falseArgs=*/ArrayRef<Value>());
407
408 // Ok, we're done!
409 rewriter.replaceOp(ifOp, continueBlock->getArguments());
410 return success();
411 }
412
413 LogicalResult
matchAndRewrite(ExecuteRegionOp op,PatternRewriter & rewriter) const414 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
415 PatternRewriter &rewriter) const {
416 auto loc = op.getLoc();
417
418 auto *condBlock = rewriter.getInsertionBlock();
419 auto opPosition = rewriter.getInsertionPoint();
420 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
421
422 auto ®ion = op.getRegion();
423 rewriter.setInsertionPointToEnd(condBlock);
424 rewriter.create<cf::BranchOp>(loc, ®ion.front());
425
426 for (Block &block : region) {
427 if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
428 ValueRange terminatorOperands = terminator->getOperands();
429 rewriter.setInsertionPointToEnd(&block);
430 rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
431 rewriter.eraseOp(terminator);
432 }
433 }
434
435 rewriter.inlineRegionBefore(region, remainingOpsBlock);
436
437 SmallVector<Value> vals;
438 SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
439 for (auto arg :
440 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
441 vals.push_back(arg);
442 rewriter.replaceOp(op, vals);
443 return success();
444 }
445
446 LogicalResult
matchAndRewrite(ParallelOp parallelOp,PatternRewriter & rewriter) const447 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
448 PatternRewriter &rewriter) const {
449 Location loc = parallelOp.getLoc();
450
451 // For a parallel loop, we essentially need to create an n-dimensional loop
452 // nest. We do this by translating to scf.for ops and have those lowered in
453 // a further rewrite. If a parallel loop contains reductions (and thus returns
454 // values), forward the initial values for the reductions down the loop
455 // hierarchy and bubble up the results by modifying the "yield" terminator.
456 SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
457 SmallVector<Value, 4> ivs;
458 ivs.reserve(parallelOp.getNumLoops());
459 bool first = true;
460 SmallVector<Value, 4> loopResults(iterArgs);
461 for (auto loopOperands :
462 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
463 parallelOp.getUpperBound(), parallelOp.getStep())) {
464 Value iv, lower, upper, step;
465 std::tie(iv, lower, upper, step) = loopOperands;
466 ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
467 ivs.push_back(forOp.getInductionVar());
468 auto iterRange = forOp.getRegionIterArgs();
469 iterArgs.assign(iterRange.begin(), iterRange.end());
470
471 if (first) {
472 // Store the results of the outermost loop that will be used to replace
473 // the results of the parallel loop when it is fully rewritten.
474 loopResults.assign(forOp.result_begin(), forOp.result_end());
475 first = false;
476 } else if (!forOp.getResults().empty()) {
477 // A loop is constructed with an empty "yield" terminator if there are
478 // no results.
479 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
480 rewriter.create<scf::YieldOp>(loc, forOp.getResults());
481 }
482
483 rewriter.setInsertionPointToStart(forOp.getBody());
484 }
485
486 // First, merge reduction blocks into the main region.
487 SmallVector<Value, 4> yieldOperands;
488 yieldOperands.reserve(parallelOp.getNumResults());
489 for (auto &op : *parallelOp.getBody()) {
490 auto reduce = dyn_cast<ReduceOp>(op);
491 if (!reduce)
492 continue;
493
494 Block &reduceBlock = reduce.getReductionOperator().front();
495 Value arg = iterArgs[yieldOperands.size()];
496 yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
497 rewriter.eraseOp(reduceBlock.getTerminator());
498 rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
499 rewriter.eraseOp(reduce);
500 }
501
502 // Then merge the loop body without the terminator.
503 rewriter.eraseOp(parallelOp.getBody()->getTerminator());
504 Block *newBody = rewriter.getInsertionBlock();
505 if (newBody->empty())
506 rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
507 else
508 rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
509 ivs);
510
511 // Finally, create the terminator if required (for loops with no results, it
512 // has been already created in loop construction).
513 if (!yieldOperands.empty()) {
514 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
515 rewriter.create<scf::YieldOp>(loc, yieldOperands);
516 }
517
518 rewriter.replaceOp(parallelOp, loopResults);
519
520 return success();
521 }
522
matchAndRewrite(WhileOp whileOp,PatternRewriter & rewriter) const523 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
524 PatternRewriter &rewriter) const {
525 OpBuilder::InsertionGuard guard(rewriter);
526 Location loc = whileOp.getLoc();
527
528 // Split the current block before the WhileOp to create the inlining point.
529 Block *currentBlock = rewriter.getInsertionBlock();
530 Block *continuation =
531 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
532
533 // Inline both regions.
534 Block *after = &whileOp.getAfter().front();
535 Block *afterLast = &whileOp.getAfter().back();
536 Block *before = &whileOp.getBefore().front();
537 Block *beforeLast = &whileOp.getBefore().back();
538 rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
539 rewriter.inlineRegionBefore(whileOp.getBefore(), after);
540
541 // Branch to the "before" region.
542 rewriter.setInsertionPointToEnd(currentBlock);
543 rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
544
545 // Replace terminators with branches. Assuming bodies are SESE, which holds
546 // given only the patterns from this file, we only need to look at the last
547 // block. This should be reconsidered if we allow break/continue in SCF.
548 rewriter.setInsertionPointToEnd(beforeLast);
549 auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
550 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
551 after, condOp.getArgs(),
552 continuation, ValueRange());
553
554 rewriter.setInsertionPointToEnd(afterLast);
555 auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
556 rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
557 yieldOp.getResults());
558
559 // Replace the op with values "yielded" from the "before" region, which are
560 // visible by dominance.
561 rewriter.replaceOp(whileOp, condOp.getArgs());
562
563 return success();
564 }
565
566 LogicalResult
matchAndRewrite(WhileOp whileOp,PatternRewriter & rewriter) const567 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
568 PatternRewriter &rewriter) const {
569 if (!llvm::hasSingleElement(whileOp.getAfter()))
570 return rewriter.notifyMatchFailure(whileOp,
571 "do-while simplification applicable to "
572 "single-block 'after' region only");
573
574 Block &afterBlock = whileOp.getAfter().front();
575 if (!llvm::hasSingleElement(afterBlock))
576 return rewriter.notifyMatchFailure(whileOp,
577 "do-while simplification applicable "
578 "only if 'after' region has no payload");
579
580 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
581 if (!yield || yield.getResults() != afterBlock.getArguments())
582 return rewriter.notifyMatchFailure(whileOp,
583 "do-while simplification applicable "
584 "only to forwarding 'after' regions");
585
586 // Split the current block before the WhileOp to create the inlining point.
587 OpBuilder::InsertionGuard guard(rewriter);
588 Block *currentBlock = rewriter.getInsertionBlock();
589 Block *continuation =
590 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
591
592 // Only the "before" region should be inlined.
593 Block *before = &whileOp.getBefore().front();
594 Block *beforeLast = &whileOp.getBefore().back();
595 rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
596
597 // Branch to the "before" region.
598 rewriter.setInsertionPointToEnd(currentBlock);
599 rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
600
601 // Loop around the "before" region based on condition.
602 rewriter.setInsertionPointToEnd(beforeLast);
603 auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
604 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
605 before, condOp.getArgs(),
606 continuation, ValueRange());
607
608 // Replace the op with values "yielded" from the "before" region, which are
609 // visible by dominance.
610 rewriter.replaceOp(whileOp, condOp.getArgs());
611
612 return success();
613 }
614
populateSCFToControlFlowConversionPatterns(RewritePatternSet & patterns)615 void mlir::populateSCFToControlFlowConversionPatterns(
616 RewritePatternSet &patterns) {
617 patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
618 ExecuteRegionLowering>(patterns.getContext());
619 patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
620 }
621
runOnOperation()622 void SCFToControlFlowPass::runOnOperation() {
623 RewritePatternSet patterns(&getContext());
624 populateSCFToControlFlowConversionPatterns(patterns);
625
626 // Configure conversion to lower out SCF operations.
627 ConversionTarget target(getContext());
628 target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
629 scf::ExecuteRegionOp>();
630 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
631 if (failed(
632 applyPartialConversion(getOperation(), target, std::move(patterns))))
633 signalPassFailure();
634 }
635
createConvertSCFToCFPass()636 std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
637 return std::make_unique<SCFToControlFlowPass>();
638 }
639