1*ace01605SRiver Riddle //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2*ace01605SRiver Riddle //
3*ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5*ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*ace01605SRiver Riddle //
7*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8*ace01605SRiver Riddle //
// This file implements a pass, and the rewrite patterns it relies on, to lower
// scf.for, scf.if, scf.while, scf.execute_region and scf.parallel ops into
// ControlFlow (cf) dialect CFG ops.
11*ace01605SRiver Riddle //
12*ace01605SRiver Riddle //===----------------------------------------------------------------------===//
13*ace01605SRiver Riddle 
14*ace01605SRiver Riddle #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
15*ace01605SRiver Riddle #include "../PassDetail.h"
16*ace01605SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17*ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18*ace01605SRiver Riddle #include "mlir/Dialect/SCF/SCF.h"
19*ace01605SRiver Riddle #include "mlir/IR/BlockAndValueMapping.h"
20*ace01605SRiver Riddle #include "mlir/IR/Builders.h"
21*ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h"
22*ace01605SRiver Riddle #include "mlir/IR/MLIRContext.h"
23*ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h"
24*ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
25*ace01605SRiver Riddle #include "mlir/Transforms/Passes.h"
26*ace01605SRiver Riddle 
27*ace01605SRiver Riddle using namespace mlir;
28*ace01605SRiver Riddle using namespace mlir::scf;
29*ace01605SRiver Riddle 
30*ace01605SRiver Riddle namespace {
31*ace01605SRiver Riddle 
32*ace01605SRiver Riddle struct SCFToControlFlowPass
33*ace01605SRiver Riddle     : public SCFToControlFlowBase<SCFToControlFlowPass> {
34*ace01605SRiver Riddle   void runOnOperation() override;
35*ace01605SRiver Riddle };
36*ace01605SRiver Riddle 
37*ace01605SRiver Riddle // Create a CFG subgraph for the loop around its body blocks (if the body
38*ace01605SRiver Riddle // contained other loops, they have been already lowered to a flow of blocks).
39*ace01605SRiver Riddle // Maintain the invariants that a CFG subgraph created for any loop has a single
40*ace01605SRiver Riddle // entry and a single exit, and that the entry/exit blocks are respectively
41*ace01605SRiver Riddle // first/last blocks in the parent region.  The original loop operation is
42*ace01605SRiver Riddle // replaced by the initialization operations that set up the initial value of
43*ace01605SRiver Riddle // the loop induction variable (%iv) and computes the loop bounds that are loop-
44*ace01605SRiver Riddle // invariant for affine loops.  The operations following the original scf.for
45*ace01605SRiver Riddle // are split out into a separate continuation (exit) block. A condition block is
46*ace01605SRiver Riddle // created before the continuation block. It checks the exit condition of the
47*ace01605SRiver Riddle // loop and branches either to the continuation block, or to the first block of
48*ace01605SRiver Riddle // the body. The condition block takes as arguments the values of the induction
49*ace01605SRiver Riddle // variable followed by loop-carried values. Since it dominates both the body
50*ace01605SRiver Riddle // blocks and the continuation block, loop-carried values are visible in all of
51*ace01605SRiver Riddle // those blocks. Induction variable modification is appended to the last block
52*ace01605SRiver Riddle // of the body (which is the exit block from the body subgraph thanks to the
53*ace01605SRiver Riddle // invariant we maintain) along with a branch that loops back to the condition
54*ace01605SRiver Riddle // block. Loop-carried values are the loop terminator operands, which are
55*ace01605SRiver Riddle // forwarded to the branch.
56*ace01605SRiver Riddle //
57*ace01605SRiver Riddle //      +---------------------------------+
58*ace01605SRiver Riddle //      |   <code before the ForOp>       |
59*ace01605SRiver Riddle //      |   <definitions of %init...>     |
60*ace01605SRiver Riddle //      |   <compute initial %iv value>   |
61*ace01605SRiver Riddle //      |   cf.br cond(%iv, %init...)        |
62*ace01605SRiver Riddle //      +---------------------------------+
63*ace01605SRiver Riddle //             |
64*ace01605SRiver Riddle //  -------|   |
65*ace01605SRiver Riddle //  |      v   v
66*ace01605SRiver Riddle //  |   +--------------------------------+
67*ace01605SRiver Riddle //  |   | cond(%iv, %init...):           |
68*ace01605SRiver Riddle //  |   |   <compare %iv to upper bound> |
69*ace01605SRiver Riddle //  |   |   cf.cond_br %r, body, end        |
70*ace01605SRiver Riddle //  |   +--------------------------------+
71*ace01605SRiver Riddle //  |          |               |
72*ace01605SRiver Riddle //  |          |               -------------|
73*ace01605SRiver Riddle //  |          v                            |
74*ace01605SRiver Riddle //  |   +--------------------------------+  |
75*ace01605SRiver Riddle //  |   | body-first:                    |  |
76*ace01605SRiver Riddle //  |   |   <%init visible by dominance> |  |
77*ace01605SRiver Riddle //  |   |   <body contents>              |  |
78*ace01605SRiver Riddle //  |   +--------------------------------+  |
79*ace01605SRiver Riddle //  |                   |                   |
80*ace01605SRiver Riddle //  |                  ...                  |
81*ace01605SRiver Riddle //  |                   |                   |
82*ace01605SRiver Riddle //  |   +--------------------------------+  |
83*ace01605SRiver Riddle //  |   | body-last:                     |  |
84*ace01605SRiver Riddle //  |   |   <body contents>              |  |
85*ace01605SRiver Riddle //  |   |   <operands of yield = %yields>|  |
86*ace01605SRiver Riddle //  |   |   %new_iv =<add step to %iv>   |  |
87*ace01605SRiver Riddle //  |   |   cf.br cond(%new_iv, %yields)    |  |
88*ace01605SRiver Riddle //  |   +--------------------------------+  |
89*ace01605SRiver Riddle //  |          |                            |
90*ace01605SRiver Riddle //  |-----------        |--------------------
91*ace01605SRiver Riddle //                      v
92*ace01605SRiver Riddle //      +--------------------------------+
93*ace01605SRiver Riddle //      | end:                           |
94*ace01605SRiver Riddle //      |   <code after the ForOp>       |
95*ace01605SRiver Riddle //      |   <%init visible by dominance> |
96*ace01605SRiver Riddle //      +--------------------------------+
97*ace01605SRiver Riddle //
98*ace01605SRiver Riddle struct ForLowering : public OpRewritePattern<ForOp> {
99*ace01605SRiver Riddle   using OpRewritePattern<ForOp>::OpRewritePattern;
100*ace01605SRiver Riddle 
101*ace01605SRiver Riddle   LogicalResult matchAndRewrite(ForOp forOp,
102*ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override;
103*ace01605SRiver Riddle };
104*ace01605SRiver Riddle 
105*ace01605SRiver Riddle // Create a CFG subgraph for the scf.if operation (including its "then" and
106*ace01605SRiver Riddle // optional "else" operation blocks).  We maintain the invariants that the
107*ace01605SRiver Riddle // subgraph has a single entry and a single exit point, and that the entry/exit
108*ace01605SRiver Riddle // blocks are respectively the first/last block of the enclosing region. The
109*ace01605SRiver Riddle // operations following the scf.if are split into a continuation (subgraph
110*ace01605SRiver Riddle // exit) block. The condition is lowered to a chain of blocks that implement the
111*ace01605SRiver Riddle // short-circuit scheme. The "scf.if" operation is replaced with a conditional
112*ace01605SRiver Riddle // branch to either the first block of the "then" region, or to the first block
113*ace01605SRiver Riddle // of the "else" region. In these blocks, "scf.yield" is unconditional branches
114*ace01605SRiver Riddle // to the post-dominating block. When the "scf.if" does not return values, the
115*ace01605SRiver Riddle // post-dominating block is the same as the continuation block. When it returns
116*ace01605SRiver Riddle // values, the post-dominating block is a new block with arguments that
117*ace01605SRiver Riddle // correspond to the values returned by the "scf.if" that unconditionally
118*ace01605SRiver Riddle // branches to the continuation block. This allows block arguments to dominate
119*ace01605SRiver Riddle // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
120*ace01605SRiver Riddle // new block allows us to avoid modifying the argument list of an existing
121*ace01605SRiver Riddle // block, which is illegal in a conversion pattern). When the "else" region is
122*ace01605SRiver Riddle // empty, which is only allowed for "scf.if"s that don't return values, the
123*ace01605SRiver Riddle // condition branches directly to the continuation block.
124*ace01605SRiver Riddle //
125*ace01605SRiver Riddle // CFG for a scf.if with else and without results.
126*ace01605SRiver Riddle //
127*ace01605SRiver Riddle //      +--------------------------------+
128*ace01605SRiver Riddle //      | <code before the IfOp>         |
129*ace01605SRiver Riddle //      | cf.cond_br %cond, %then, %else    |
130*ace01605SRiver Riddle //      +--------------------------------+
131*ace01605SRiver Riddle //             |              |
132*ace01605SRiver Riddle //             |              --------------|
133*ace01605SRiver Riddle //             v                            |
134*ace01605SRiver Riddle //      +--------------------------------+  |
135*ace01605SRiver Riddle //      | then:                          |  |
136*ace01605SRiver Riddle //      |   <then contents>              |  |
137*ace01605SRiver Riddle //      |   cf.br continue                  |  |
138*ace01605SRiver Riddle //      +--------------------------------+  |
139*ace01605SRiver Riddle //             |                            |
140*ace01605SRiver Riddle //   |----------               |-------------
141*ace01605SRiver Riddle //   |                         V
142*ace01605SRiver Riddle //   |  +--------------------------------+
143*ace01605SRiver Riddle //   |  | else:                          |
144*ace01605SRiver Riddle //   |  |   <else contents>              |
145*ace01605SRiver Riddle //   |  |   cf.br continue                  |
146*ace01605SRiver Riddle //   |  +--------------------------------+
147*ace01605SRiver Riddle //   |         |
148*ace01605SRiver Riddle //   ------|   |
149*ace01605SRiver Riddle //         v   v
150*ace01605SRiver Riddle //      +--------------------------------+
151*ace01605SRiver Riddle //      | continue:                      |
152*ace01605SRiver Riddle //      |   <code after the IfOp>        |
153*ace01605SRiver Riddle //      +--------------------------------+
154*ace01605SRiver Riddle //
155*ace01605SRiver Riddle // CFG for a scf.if with results.
156*ace01605SRiver Riddle //
157*ace01605SRiver Riddle //      +--------------------------------+
158*ace01605SRiver Riddle //      | <code before the IfOp>         |
159*ace01605SRiver Riddle //      | cf.cond_br %cond, %then, %else    |
160*ace01605SRiver Riddle //      +--------------------------------+
161*ace01605SRiver Riddle //             |              |
162*ace01605SRiver Riddle //             |              --------------|
163*ace01605SRiver Riddle //             v                            |
164*ace01605SRiver Riddle //      +--------------------------------+  |
165*ace01605SRiver Riddle //      | then:                          |  |
166*ace01605SRiver Riddle //      |   <then contents>              |  |
167*ace01605SRiver Riddle //      |   cf.br dom(%args...)             |  |
168*ace01605SRiver Riddle //      +--------------------------------+  |
169*ace01605SRiver Riddle //             |                            |
170*ace01605SRiver Riddle //   |----------               |-------------
171*ace01605SRiver Riddle //   |                         V
172*ace01605SRiver Riddle //   |  +--------------------------------+
173*ace01605SRiver Riddle //   |  | else:                          |
174*ace01605SRiver Riddle //   |  |   <else contents>              |
175*ace01605SRiver Riddle //   |  |   cf.br dom(%args...)             |
176*ace01605SRiver Riddle //   |  +--------------------------------+
177*ace01605SRiver Riddle //   |         |
178*ace01605SRiver Riddle //   ------|   |
179*ace01605SRiver Riddle //         v   v
180*ace01605SRiver Riddle //      +--------------------------------+
181*ace01605SRiver Riddle //      | dom(%args...):                 |
182*ace01605SRiver Riddle //      |   cf.br continue                  |
183*ace01605SRiver Riddle //      +--------------------------------+
184*ace01605SRiver Riddle //             |
185*ace01605SRiver Riddle //             v
186*ace01605SRiver Riddle //      +--------------------------------+
187*ace01605SRiver Riddle //      | continue:                      |
188*ace01605SRiver Riddle //      | <code after the IfOp>          |
189*ace01605SRiver Riddle //      +--------------------------------+
190*ace01605SRiver Riddle //
191*ace01605SRiver Riddle struct IfLowering : public OpRewritePattern<IfOp> {
192*ace01605SRiver Riddle   using OpRewritePattern<IfOp>::OpRewritePattern;
193*ace01605SRiver Riddle 
194*ace01605SRiver Riddle   LogicalResult matchAndRewrite(IfOp ifOp,
195*ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override;
196*ace01605SRiver Riddle };
197*ace01605SRiver Riddle 
198*ace01605SRiver Riddle struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
199*ace01605SRiver Riddle   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
200*ace01605SRiver Riddle 
201*ace01605SRiver Riddle   LogicalResult matchAndRewrite(ExecuteRegionOp op,
202*ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override;
203*ace01605SRiver Riddle };
204*ace01605SRiver Riddle 
205*ace01605SRiver Riddle struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
206*ace01605SRiver Riddle   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
207*ace01605SRiver Riddle 
208*ace01605SRiver Riddle   LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
209*ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override;
210*ace01605SRiver Riddle };
211*ace01605SRiver Riddle 
212*ace01605SRiver Riddle /// Create a CFG subgraph for this loop construct. The regions of the loop need
213*ace01605SRiver Riddle /// not be a single block anymore (for example, if other SCF constructs that
214*ace01605SRiver Riddle /// they contain have been already converted to CFG), but need to be single-exit
215*ace01605SRiver Riddle /// from the last block of each region. The operations following the original
216*ace01605SRiver Riddle /// WhileOp are split into a new continuation block. Both regions of the WhileOp
217*ace01605SRiver Riddle /// are inlined, and their terminators are rewritten to organize the control
218*ace01605SRiver Riddle /// flow implementing the loop as follows.
219*ace01605SRiver Riddle ///
220*ace01605SRiver Riddle ///      +---------------------------------+
221*ace01605SRiver Riddle ///      |   <code before the WhileOp>     |
222*ace01605SRiver Riddle ///      |   cf.br ^before(%operands...)      |
223*ace01605SRiver Riddle ///      +---------------------------------+
224*ace01605SRiver Riddle ///             |
225*ace01605SRiver Riddle ///  -------|   |
226*ace01605SRiver Riddle ///  |      v   v
227*ace01605SRiver Riddle ///  |   +--------------------------------+
228*ace01605SRiver Riddle ///  |   | ^before(%bargs...):            |
229*ace01605SRiver Riddle ///  |   |   %vals... = <some payload>    |
230*ace01605SRiver Riddle ///  |   +--------------------------------+
231*ace01605SRiver Riddle ///  |                   |
232*ace01605SRiver Riddle ///  |                  ...
233*ace01605SRiver Riddle ///  |                   |
234*ace01605SRiver Riddle ///  |   +--------------------------------+
235*ace01605SRiver Riddle ///  |   | ^before-last:
236*ace01605SRiver Riddle ///  |   |   %cond = <compute condition>  |
237*ace01605SRiver Riddle ///  |   |   cf.cond_br %cond,               |
238*ace01605SRiver Riddle ///  |   |        ^after(%vals...), ^cont |
239*ace01605SRiver Riddle ///  |   +--------------------------------+
240*ace01605SRiver Riddle ///  |          |               |
241*ace01605SRiver Riddle ///  |          |               -------------|
242*ace01605SRiver Riddle ///  |          v                            |
243*ace01605SRiver Riddle ///  |   +--------------------------------+  |
244*ace01605SRiver Riddle ///  |   | ^after(%aargs...):             |  |
245*ace01605SRiver Riddle ///  |   |   <body contents>              |  |
246*ace01605SRiver Riddle ///  |   +--------------------------------+  |
247*ace01605SRiver Riddle ///  |                   |                   |
248*ace01605SRiver Riddle ///  |                  ...                  |
249*ace01605SRiver Riddle ///  |                   |                   |
250*ace01605SRiver Riddle ///  |   +--------------------------------+  |
251*ace01605SRiver Riddle ///  |   | ^after-last:                   |  |
252*ace01605SRiver Riddle ///  |   |   %yields... = <some payload>  |  |
253*ace01605SRiver Riddle ///  |   |   cf.br ^before(%yields...)       |  |
254*ace01605SRiver Riddle ///  |   +--------------------------------+  |
255*ace01605SRiver Riddle ///  |          |                            |
256*ace01605SRiver Riddle ///  |-----------        |--------------------
257*ace01605SRiver Riddle ///                      v
258*ace01605SRiver Riddle ///      +--------------------------------+
259*ace01605SRiver Riddle ///      | ^cont:                         |
260*ace01605SRiver Riddle ///      |   <code after the WhileOp>     |
261*ace01605SRiver Riddle ///      |   <%vals from 'before' region  |
262*ace01605SRiver Riddle ///      |          visible by dominance> |
263*ace01605SRiver Riddle ///      +--------------------------------+
264*ace01605SRiver Riddle ///
265*ace01605SRiver Riddle /// Values are communicated between ex-regions (the groups of blocks that used
266*ace01605SRiver Riddle /// to form a region before inlining) through block arguments of their
267*ace01605SRiver Riddle /// entry blocks, which are visible in all other dominated blocks. Similarly,
268*ace01605SRiver Riddle /// the results of the WhileOp are defined in the 'before' region, which is
269*ace01605SRiver Riddle /// required to have a single existing block, and are therefore accessible in
270*ace01605SRiver Riddle /// the continuation block due to dominance.
271*ace01605SRiver Riddle struct WhileLowering : public OpRewritePattern<WhileOp> {
272*ace01605SRiver Riddle   using OpRewritePattern<WhileOp>::OpRewritePattern;
273*ace01605SRiver Riddle 
274*ace01605SRiver Riddle   LogicalResult matchAndRewrite(WhileOp whileOp,
275*ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override;
276*ace01605SRiver Riddle };
277*ace01605SRiver Riddle 
278*ace01605SRiver Riddle /// Optimized version of the above for the case of the "after" region merely
279*ace01605SRiver Riddle /// forwarding its arguments back to the "before" region (i.e., a "do-while"
280*ace01605SRiver Riddle /// loop). This avoid inlining the "after" region completely and branches back
281*ace01605SRiver Riddle /// to the "before" entry instead.
282*ace01605SRiver Riddle struct DoWhileLowering : public OpRewritePattern<WhileOp> {
283*ace01605SRiver Riddle   using OpRewritePattern<WhileOp>::OpRewritePattern;
284*ace01605SRiver Riddle 
285*ace01605SRiver Riddle   LogicalResult matchAndRewrite(WhileOp whileOp,
286*ace01605SRiver Riddle                                 PatternRewriter &rewriter) const override;
287*ace01605SRiver Riddle };
288*ace01605SRiver Riddle } // namespace
289*ace01605SRiver Riddle 
290*ace01605SRiver Riddle LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
291*ace01605SRiver Riddle                                            PatternRewriter &rewriter) const {
292*ace01605SRiver Riddle   Location loc = forOp.getLoc();
293*ace01605SRiver Riddle 
294*ace01605SRiver Riddle   // Start by splitting the block containing the 'scf.for' into two parts.
295*ace01605SRiver Riddle   // The part before will get the init code, the part after will be the end
296*ace01605SRiver Riddle   // point.
297*ace01605SRiver Riddle   auto *initBlock = rewriter.getInsertionBlock();
298*ace01605SRiver Riddle   auto initPosition = rewriter.getInsertionPoint();
299*ace01605SRiver Riddle   auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
300*ace01605SRiver Riddle 
301*ace01605SRiver Riddle   // Use the first block of the loop body as the condition block since it is the
302*ace01605SRiver Riddle   // block that has the induction variable and loop-carried values as arguments.
303*ace01605SRiver Riddle   // Split out all operations from the first block into a new block. Move all
304*ace01605SRiver Riddle   // body blocks from the loop body region to the region containing the loop.
305*ace01605SRiver Riddle   auto *conditionBlock = &forOp.getRegion().front();
306*ace01605SRiver Riddle   auto *firstBodyBlock =
307*ace01605SRiver Riddle       rewriter.splitBlock(conditionBlock, conditionBlock->begin());
308*ace01605SRiver Riddle   auto *lastBodyBlock = &forOp.getRegion().back();
309*ace01605SRiver Riddle   rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
310*ace01605SRiver Riddle   auto iv = conditionBlock->getArgument(0);
311*ace01605SRiver Riddle 
312*ace01605SRiver Riddle   // Append the induction variable stepping logic to the last body block and
313*ace01605SRiver Riddle   // branch back to the condition block. Loop-carried values are taken from
314*ace01605SRiver Riddle   // operands of the loop terminator.
315*ace01605SRiver Riddle   Operation *terminator = lastBodyBlock->getTerminator();
316*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(lastBodyBlock);
317*ace01605SRiver Riddle   auto step = forOp.getStep();
318*ace01605SRiver Riddle   auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
319*ace01605SRiver Riddle   if (!stepped)
320*ace01605SRiver Riddle     return failure();
321*ace01605SRiver Riddle 
322*ace01605SRiver Riddle   SmallVector<Value, 8> loopCarried;
323*ace01605SRiver Riddle   loopCarried.push_back(stepped);
324*ace01605SRiver Riddle   loopCarried.append(terminator->operand_begin(), terminator->operand_end());
325*ace01605SRiver Riddle   rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
326*ace01605SRiver Riddle   rewriter.eraseOp(terminator);
327*ace01605SRiver Riddle 
328*ace01605SRiver Riddle   // Compute loop bounds before branching to the condition.
329*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(initBlock);
330*ace01605SRiver Riddle   Value lowerBound = forOp.getLowerBound();
331*ace01605SRiver Riddle   Value upperBound = forOp.getUpperBound();
332*ace01605SRiver Riddle   if (!lowerBound || !upperBound)
333*ace01605SRiver Riddle     return failure();
334*ace01605SRiver Riddle 
335*ace01605SRiver Riddle   // The initial values of loop-carried values is obtained from the operands
336*ace01605SRiver Riddle   // of the loop operation.
337*ace01605SRiver Riddle   SmallVector<Value, 8> destOperands;
338*ace01605SRiver Riddle   destOperands.push_back(lowerBound);
339*ace01605SRiver Riddle   auto iterOperands = forOp.getIterOperands();
340*ace01605SRiver Riddle   destOperands.append(iterOperands.begin(), iterOperands.end());
341*ace01605SRiver Riddle   rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
342*ace01605SRiver Riddle 
343*ace01605SRiver Riddle   // With the body block done, we can fill in the condition block.
344*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(conditionBlock);
345*ace01605SRiver Riddle   auto comparison = rewriter.create<arith::CmpIOp>(
346*ace01605SRiver Riddle       loc, arith::CmpIPredicate::slt, iv, upperBound);
347*ace01605SRiver Riddle 
348*ace01605SRiver Riddle   rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
349*ace01605SRiver Riddle                                     ArrayRef<Value>(), endBlock,
350*ace01605SRiver Riddle                                     ArrayRef<Value>());
351*ace01605SRiver Riddle   // The result of the loop operation is the values of the condition block
352*ace01605SRiver Riddle   // arguments except the induction variable on the last iteration.
353*ace01605SRiver Riddle   rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
354*ace01605SRiver Riddle   return success();
355*ace01605SRiver Riddle }
356*ace01605SRiver Riddle 
357*ace01605SRiver Riddle LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
358*ace01605SRiver Riddle                                           PatternRewriter &rewriter) const {
359*ace01605SRiver Riddle   auto loc = ifOp.getLoc();
360*ace01605SRiver Riddle 
361*ace01605SRiver Riddle   // Start by splitting the block containing the 'scf.if' into two parts.
362*ace01605SRiver Riddle   // The part before will contain the condition, the part after will be the
363*ace01605SRiver Riddle   // continuation point.
364*ace01605SRiver Riddle   auto *condBlock = rewriter.getInsertionBlock();
365*ace01605SRiver Riddle   auto opPosition = rewriter.getInsertionPoint();
366*ace01605SRiver Riddle   auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
367*ace01605SRiver Riddle   Block *continueBlock;
368*ace01605SRiver Riddle   if (ifOp.getNumResults() == 0) {
369*ace01605SRiver Riddle     continueBlock = remainingOpsBlock;
370*ace01605SRiver Riddle   } else {
371*ace01605SRiver Riddle     continueBlock =
372*ace01605SRiver Riddle         rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
373*ace01605SRiver Riddle                              SmallVector<Location>(ifOp.getNumResults(), loc));
374*ace01605SRiver Riddle     rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
375*ace01605SRiver Riddle   }
376*ace01605SRiver Riddle 
377*ace01605SRiver Riddle   // Move blocks from the "then" region to the region containing 'scf.if',
378*ace01605SRiver Riddle   // place it before the continuation block, and branch to it.
379*ace01605SRiver Riddle   auto &thenRegion = ifOp.getThenRegion();
380*ace01605SRiver Riddle   auto *thenBlock = &thenRegion.front();
381*ace01605SRiver Riddle   Operation *thenTerminator = thenRegion.back().getTerminator();
382*ace01605SRiver Riddle   ValueRange thenTerminatorOperands = thenTerminator->getOperands();
383*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(&thenRegion.back());
384*ace01605SRiver Riddle   rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
385*ace01605SRiver Riddle   rewriter.eraseOp(thenTerminator);
386*ace01605SRiver Riddle   rewriter.inlineRegionBefore(thenRegion, continueBlock);
387*ace01605SRiver Riddle 
388*ace01605SRiver Riddle   // Move blocks from the "else" region (if present) to the region containing
389*ace01605SRiver Riddle   // 'scf.if', place it before the continuation block and branch to it.  It
390*ace01605SRiver Riddle   // will be placed after the "then" regions.
391*ace01605SRiver Riddle   auto *elseBlock = continueBlock;
392*ace01605SRiver Riddle   auto &elseRegion = ifOp.getElseRegion();
393*ace01605SRiver Riddle   if (!elseRegion.empty()) {
394*ace01605SRiver Riddle     elseBlock = &elseRegion.front();
395*ace01605SRiver Riddle     Operation *elseTerminator = elseRegion.back().getTerminator();
396*ace01605SRiver Riddle     ValueRange elseTerminatorOperands = elseTerminator->getOperands();
397*ace01605SRiver Riddle     rewriter.setInsertionPointToEnd(&elseRegion.back());
398*ace01605SRiver Riddle     rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
399*ace01605SRiver Riddle     rewriter.eraseOp(elseTerminator);
400*ace01605SRiver Riddle     rewriter.inlineRegionBefore(elseRegion, continueBlock);
401*ace01605SRiver Riddle   }
402*ace01605SRiver Riddle 
403*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(condBlock);
404*ace01605SRiver Riddle   rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
405*ace01605SRiver Riddle                                     /*trueArgs=*/ArrayRef<Value>(), elseBlock,
406*ace01605SRiver Riddle                                     /*falseArgs=*/ArrayRef<Value>());
407*ace01605SRiver Riddle 
408*ace01605SRiver Riddle   // Ok, we're done!
409*ace01605SRiver Riddle   rewriter.replaceOp(ifOp, continueBlock->getArguments());
410*ace01605SRiver Riddle   return success();
411*ace01605SRiver Riddle }
412*ace01605SRiver Riddle 
413*ace01605SRiver Riddle LogicalResult
414*ace01605SRiver Riddle ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
415*ace01605SRiver Riddle                                        PatternRewriter &rewriter) const {
416*ace01605SRiver Riddle   auto loc = op.getLoc();
417*ace01605SRiver Riddle 
418*ace01605SRiver Riddle   auto *condBlock = rewriter.getInsertionBlock();
419*ace01605SRiver Riddle   auto opPosition = rewriter.getInsertionPoint();
420*ace01605SRiver Riddle   auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
421*ace01605SRiver Riddle 
422*ace01605SRiver Riddle   auto &region = op.getRegion();
423*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(condBlock);
424*ace01605SRiver Riddle   rewriter.create<cf::BranchOp>(loc, &region.front());
425*ace01605SRiver Riddle 
426*ace01605SRiver Riddle   for (Block &block : region) {
427*ace01605SRiver Riddle     if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
428*ace01605SRiver Riddle       ValueRange terminatorOperands = terminator->getOperands();
429*ace01605SRiver Riddle       rewriter.setInsertionPointToEnd(&block);
430*ace01605SRiver Riddle       rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
431*ace01605SRiver Riddle       rewriter.eraseOp(terminator);
432*ace01605SRiver Riddle     }
433*ace01605SRiver Riddle   }
434*ace01605SRiver Riddle 
435*ace01605SRiver Riddle   rewriter.inlineRegionBefore(region, remainingOpsBlock);
436*ace01605SRiver Riddle 
437*ace01605SRiver Riddle   SmallVector<Value> vals;
438*ace01605SRiver Riddle   SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
439*ace01605SRiver Riddle   for (auto arg :
440*ace01605SRiver Riddle        remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
441*ace01605SRiver Riddle     vals.push_back(arg);
442*ace01605SRiver Riddle   rewriter.replaceOp(op, vals);
443*ace01605SRiver Riddle   return success();
444*ace01605SRiver Riddle }
445*ace01605SRiver Riddle 
446*ace01605SRiver Riddle LogicalResult
447*ace01605SRiver Riddle ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
448*ace01605SRiver Riddle                                   PatternRewriter &rewriter) const {
449*ace01605SRiver Riddle   Location loc = parallelOp.getLoc();
450*ace01605SRiver Riddle 
451*ace01605SRiver Riddle   // For a parallel loop, we essentially need to create an n-dimensional loop
452*ace01605SRiver Riddle   // nest. We do this by translating to scf.for ops and have those lowered in
453*ace01605SRiver Riddle   // a further rewrite. If a parallel loop contains reductions (and thus returns
454*ace01605SRiver Riddle   // values), forward the initial values for the reductions down the loop
455*ace01605SRiver Riddle   // hierarchy and bubble up the results by modifying the "yield" terminator.
456*ace01605SRiver Riddle   SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
457*ace01605SRiver Riddle   SmallVector<Value, 4> ivs;
458*ace01605SRiver Riddle   ivs.reserve(parallelOp.getNumLoops());
459*ace01605SRiver Riddle   bool first = true;
460*ace01605SRiver Riddle   SmallVector<Value, 4> loopResults(iterArgs);
461*ace01605SRiver Riddle   for (auto loopOperands :
462*ace01605SRiver Riddle        llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
463*ace01605SRiver Riddle                  parallelOp.getUpperBound(), parallelOp.getStep())) {
464*ace01605SRiver Riddle     Value iv, lower, upper, step;
465*ace01605SRiver Riddle     std::tie(iv, lower, upper, step) = loopOperands;
466*ace01605SRiver Riddle     ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
467*ace01605SRiver Riddle     ivs.push_back(forOp.getInductionVar());
468*ace01605SRiver Riddle     auto iterRange = forOp.getRegionIterArgs();
469*ace01605SRiver Riddle     iterArgs.assign(iterRange.begin(), iterRange.end());
470*ace01605SRiver Riddle 
471*ace01605SRiver Riddle     if (first) {
472*ace01605SRiver Riddle       // Store the results of the outermost loop that will be used to replace
473*ace01605SRiver Riddle       // the results of the parallel loop when it is fully rewritten.
474*ace01605SRiver Riddle       loopResults.assign(forOp.result_begin(), forOp.result_end());
475*ace01605SRiver Riddle       first = false;
476*ace01605SRiver Riddle     } else if (!forOp.getResults().empty()) {
477*ace01605SRiver Riddle       // A loop is constructed with an empty "yield" terminator if there are
478*ace01605SRiver Riddle       // no results.
479*ace01605SRiver Riddle       rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
480*ace01605SRiver Riddle       rewriter.create<scf::YieldOp>(loc, forOp.getResults());
481*ace01605SRiver Riddle     }
482*ace01605SRiver Riddle 
483*ace01605SRiver Riddle     rewriter.setInsertionPointToStart(forOp.getBody());
484*ace01605SRiver Riddle   }
485*ace01605SRiver Riddle 
486*ace01605SRiver Riddle   // First, merge reduction blocks into the main region.
487*ace01605SRiver Riddle   SmallVector<Value, 4> yieldOperands;
488*ace01605SRiver Riddle   yieldOperands.reserve(parallelOp.getNumResults());
489*ace01605SRiver Riddle   for (auto &op : *parallelOp.getBody()) {
490*ace01605SRiver Riddle     auto reduce = dyn_cast<ReduceOp>(op);
491*ace01605SRiver Riddle     if (!reduce)
492*ace01605SRiver Riddle       continue;
493*ace01605SRiver Riddle 
494*ace01605SRiver Riddle     Block &reduceBlock = reduce.getReductionOperator().front();
495*ace01605SRiver Riddle     Value arg = iterArgs[yieldOperands.size()];
496*ace01605SRiver Riddle     yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
497*ace01605SRiver Riddle     rewriter.eraseOp(reduceBlock.getTerminator());
498*ace01605SRiver Riddle     rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
499*ace01605SRiver Riddle     rewriter.eraseOp(reduce);
500*ace01605SRiver Riddle   }
501*ace01605SRiver Riddle 
502*ace01605SRiver Riddle   // Then merge the loop body without the terminator.
503*ace01605SRiver Riddle   rewriter.eraseOp(parallelOp.getBody()->getTerminator());
504*ace01605SRiver Riddle   Block *newBody = rewriter.getInsertionBlock();
505*ace01605SRiver Riddle   if (newBody->empty())
506*ace01605SRiver Riddle     rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
507*ace01605SRiver Riddle   else
508*ace01605SRiver Riddle     rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
509*ace01605SRiver Riddle                               ivs);
510*ace01605SRiver Riddle 
511*ace01605SRiver Riddle   // Finally, create the terminator if required (for loops with no results, it
512*ace01605SRiver Riddle   // has been already created in loop construction).
513*ace01605SRiver Riddle   if (!yieldOperands.empty()) {
514*ace01605SRiver Riddle     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
515*ace01605SRiver Riddle     rewriter.create<scf::YieldOp>(loc, yieldOperands);
516*ace01605SRiver Riddle   }
517*ace01605SRiver Riddle 
518*ace01605SRiver Riddle   rewriter.replaceOp(parallelOp, loopResults);
519*ace01605SRiver Riddle 
520*ace01605SRiver Riddle   return success();
521*ace01605SRiver Riddle }
522*ace01605SRiver Riddle 
523*ace01605SRiver Riddle LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
524*ace01605SRiver Riddle                                              PatternRewriter &rewriter) const {
525*ace01605SRiver Riddle   OpBuilder::InsertionGuard guard(rewriter);
526*ace01605SRiver Riddle   Location loc = whileOp.getLoc();
527*ace01605SRiver Riddle 
528*ace01605SRiver Riddle   // Split the current block before the WhileOp to create the inlining point.
529*ace01605SRiver Riddle   Block *currentBlock = rewriter.getInsertionBlock();
530*ace01605SRiver Riddle   Block *continuation =
531*ace01605SRiver Riddle       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
532*ace01605SRiver Riddle 
533*ace01605SRiver Riddle   // Inline both regions.
534*ace01605SRiver Riddle   Block *after = &whileOp.getAfter().front();
535*ace01605SRiver Riddle   Block *afterLast = &whileOp.getAfter().back();
536*ace01605SRiver Riddle   Block *before = &whileOp.getBefore().front();
537*ace01605SRiver Riddle   Block *beforeLast = &whileOp.getBefore().back();
538*ace01605SRiver Riddle   rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
539*ace01605SRiver Riddle   rewriter.inlineRegionBefore(whileOp.getBefore(), after);
540*ace01605SRiver Riddle 
541*ace01605SRiver Riddle   // Branch to the "before" region.
542*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(currentBlock);
543*ace01605SRiver Riddle   rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
544*ace01605SRiver Riddle 
545*ace01605SRiver Riddle   // Replace terminators with branches. Assuming bodies are SESE, which holds
546*ace01605SRiver Riddle   // given only the patterns from this file, we only need to look at the last
547*ace01605SRiver Riddle   // block. This should be reconsidered if we allow break/continue in SCF.
548*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(beforeLast);
549*ace01605SRiver Riddle   auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
550*ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
551*ace01605SRiver Riddle                                                 after, condOp.getArgs(),
552*ace01605SRiver Riddle                                                 continuation, ValueRange());
553*ace01605SRiver Riddle 
554*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(afterLast);
555*ace01605SRiver Riddle   auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
556*ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
557*ace01605SRiver Riddle                                             yieldOp.getResults());
558*ace01605SRiver Riddle 
559*ace01605SRiver Riddle   // Replace the op with values "yielded" from the "before" region, which are
560*ace01605SRiver Riddle   // visible by dominance.
561*ace01605SRiver Riddle   rewriter.replaceOp(whileOp, condOp.getArgs());
562*ace01605SRiver Riddle 
563*ace01605SRiver Riddle   return success();
564*ace01605SRiver Riddle }
565*ace01605SRiver Riddle 
566*ace01605SRiver Riddle LogicalResult
567*ace01605SRiver Riddle DoWhileLowering::matchAndRewrite(WhileOp whileOp,
568*ace01605SRiver Riddle                                  PatternRewriter &rewriter) const {
569*ace01605SRiver Riddle   if (!llvm::hasSingleElement(whileOp.getAfter()))
570*ace01605SRiver Riddle     return rewriter.notifyMatchFailure(whileOp,
571*ace01605SRiver Riddle                                        "do-while simplification applicable to "
572*ace01605SRiver Riddle                                        "single-block 'after' region only");
573*ace01605SRiver Riddle 
574*ace01605SRiver Riddle   Block &afterBlock = whileOp.getAfter().front();
575*ace01605SRiver Riddle   if (!llvm::hasSingleElement(afterBlock))
576*ace01605SRiver Riddle     return rewriter.notifyMatchFailure(whileOp,
577*ace01605SRiver Riddle                                        "do-while simplification applicable "
578*ace01605SRiver Riddle                                        "only if 'after' region has no payload");
579*ace01605SRiver Riddle 
580*ace01605SRiver Riddle   auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
581*ace01605SRiver Riddle   if (!yield || yield.getResults() != afterBlock.getArguments())
582*ace01605SRiver Riddle     return rewriter.notifyMatchFailure(whileOp,
583*ace01605SRiver Riddle                                        "do-while simplification applicable "
584*ace01605SRiver Riddle                                        "only to forwarding 'after' regions");
585*ace01605SRiver Riddle 
586*ace01605SRiver Riddle   // Split the current block before the WhileOp to create the inlining point.
587*ace01605SRiver Riddle   OpBuilder::InsertionGuard guard(rewriter);
588*ace01605SRiver Riddle   Block *currentBlock = rewriter.getInsertionBlock();
589*ace01605SRiver Riddle   Block *continuation =
590*ace01605SRiver Riddle       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
591*ace01605SRiver Riddle 
592*ace01605SRiver Riddle   // Only the "before" region should be inlined.
593*ace01605SRiver Riddle   Block *before = &whileOp.getBefore().front();
594*ace01605SRiver Riddle   Block *beforeLast = &whileOp.getBefore().back();
595*ace01605SRiver Riddle   rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
596*ace01605SRiver Riddle 
597*ace01605SRiver Riddle   // Branch to the "before" region.
598*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(currentBlock);
599*ace01605SRiver Riddle   rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
600*ace01605SRiver Riddle 
601*ace01605SRiver Riddle   // Loop around the "before" region based on condition.
602*ace01605SRiver Riddle   rewriter.setInsertionPointToEnd(beforeLast);
603*ace01605SRiver Riddle   auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
604*ace01605SRiver Riddle   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
605*ace01605SRiver Riddle                                                 before, condOp.getArgs(),
606*ace01605SRiver Riddle                                                 continuation, ValueRange());
607*ace01605SRiver Riddle 
608*ace01605SRiver Riddle   // Replace the op with values "yielded" from the "before" region, which are
609*ace01605SRiver Riddle   // visible by dominance.
610*ace01605SRiver Riddle   rewriter.replaceOp(whileOp, condOp.getArgs());
611*ace01605SRiver Riddle 
612*ace01605SRiver Riddle   return success();
613*ace01605SRiver Riddle }
614*ace01605SRiver Riddle 
615*ace01605SRiver Riddle void mlir::populateSCFToControlFlowConversionPatterns(
616*ace01605SRiver Riddle     RewritePatternSet &patterns) {
617*ace01605SRiver Riddle   patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
618*ace01605SRiver Riddle                ExecuteRegionLowering>(patterns.getContext());
619*ace01605SRiver Riddle   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
620*ace01605SRiver Riddle }
621*ace01605SRiver Riddle 
622*ace01605SRiver Riddle void SCFToControlFlowPass::runOnOperation() {
623*ace01605SRiver Riddle   RewritePatternSet patterns(&getContext());
624*ace01605SRiver Riddle   populateSCFToControlFlowConversionPatterns(patterns);
625*ace01605SRiver Riddle 
626*ace01605SRiver Riddle   // Configure conversion to lower out SCF operations.
627*ace01605SRiver Riddle   ConversionTarget target(getContext());
628*ace01605SRiver Riddle   target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
629*ace01605SRiver Riddle                       scf::ExecuteRegionOp>();
630*ace01605SRiver Riddle   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
631*ace01605SRiver Riddle   if (failed(
632*ace01605SRiver Riddle           applyPartialConversion(getOperation(), target, std::move(patterns))))
633*ace01605SRiver Riddle     signalPassFailure();
634*ace01605SRiver Riddle }
635*ace01605SRiver Riddle 
636*ace01605SRiver Riddle std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
637*ace01605SRiver Riddle   return std::make_unique<SCFToControlFlowPass>();
638*ace01605SRiver Riddle }
639