1ace01605SRiver Riddle //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2ace01605SRiver Riddle //
3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ace01605SRiver Riddle //
7ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8ace01605SRiver Riddle //
// This file implements a pass to convert scf.for, scf.if, scf.parallel,
// scf.while and scf.execute_region ops into ControlFlow dialect (CFG) ops.
11ace01605SRiver Riddle //
12ace01605SRiver Riddle //===----------------------------------------------------------------------===//
13ace01605SRiver Riddle
14ace01605SRiver Riddle #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
15ace01605SRiver Riddle #include "../PassDetail.h"
16ace01605SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
19ace01605SRiver Riddle #include "mlir/IR/BlockAndValueMapping.h"
20ace01605SRiver Riddle #include "mlir/IR/Builders.h"
21ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h"
22ace01605SRiver Riddle #include "mlir/IR/MLIRContext.h"
23ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h"
24ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
25ace01605SRiver Riddle #include "mlir/Transforms/Passes.h"
26ace01605SRiver Riddle
27ace01605SRiver Riddle using namespace mlir;
28ace01605SRiver Riddle using namespace mlir::scf;
29ace01605SRiver Riddle
30ace01605SRiver Riddle namespace {
31ace01605SRiver Riddle
32ace01605SRiver Riddle struct SCFToControlFlowPass
33ace01605SRiver Riddle : public SCFToControlFlowBase<SCFToControlFlowPass> {
34ace01605SRiver Riddle void runOnOperation() override;
35ace01605SRiver Riddle };
36ace01605SRiver Riddle
37ace01605SRiver Riddle // Create a CFG subgraph for the loop around its body blocks (if the body
38ace01605SRiver Riddle // contained other loops, they have been already lowered to a flow of blocks).
39ace01605SRiver Riddle // Maintain the invariants that a CFG subgraph created for any loop has a single
40ace01605SRiver Riddle // entry and a single exit, and that the entry/exit blocks are respectively
41ace01605SRiver Riddle // first/last blocks in the parent region. The original loop operation is
42ace01605SRiver Riddle // replaced by the initialization operations that set up the initial value of
43ace01605SRiver Riddle // the loop induction variable (%iv) and computes the loop bounds that are loop-
44ace01605SRiver Riddle // invariant for affine loops. The operations following the original scf.for
45ace01605SRiver Riddle // are split out into a separate continuation (exit) block. A condition block is
46ace01605SRiver Riddle // created before the continuation block. It checks the exit condition of the
47ace01605SRiver Riddle // loop and branches either to the continuation block, or to the first block of
48ace01605SRiver Riddle // the body. The condition block takes as arguments the values of the induction
49ace01605SRiver Riddle // variable followed by loop-carried values. Since it dominates both the body
50ace01605SRiver Riddle // blocks and the continuation block, loop-carried values are visible in all of
51ace01605SRiver Riddle // those blocks. Induction variable modification is appended to the last block
52ace01605SRiver Riddle // of the body (which is the exit block from the body subgraph thanks to the
53ace01605SRiver Riddle // invariant we maintain) along with a branch that loops back to the condition
54ace01605SRiver Riddle // block. Loop-carried values are the loop terminator operands, which are
55ace01605SRiver Riddle // forwarded to the branch.
56ace01605SRiver Riddle //
57ace01605SRiver Riddle // +---------------------------------+
58ace01605SRiver Riddle // | <code before the ForOp> |
59ace01605SRiver Riddle // | <definitions of %init...> |
60ace01605SRiver Riddle // | <compute initial %iv value> |
61ace01605SRiver Riddle // | cf.br cond(%iv, %init...) |
62ace01605SRiver Riddle // +---------------------------------+
63ace01605SRiver Riddle // |
64ace01605SRiver Riddle // -------| |
65ace01605SRiver Riddle // | v v
66ace01605SRiver Riddle // | +--------------------------------+
67ace01605SRiver Riddle // | | cond(%iv, %init...): |
68ace01605SRiver Riddle // | | <compare %iv to upper bound> |
69ace01605SRiver Riddle // | | cf.cond_br %r, body, end |
70ace01605SRiver Riddle // | +--------------------------------+
71ace01605SRiver Riddle // | | |
72ace01605SRiver Riddle // | | -------------|
73ace01605SRiver Riddle // | v |
74ace01605SRiver Riddle // | +--------------------------------+ |
75ace01605SRiver Riddle // | | body-first: | |
76ace01605SRiver Riddle // | | <%init visible by dominance> | |
77ace01605SRiver Riddle // | | <body contents> | |
78ace01605SRiver Riddle // | +--------------------------------+ |
79ace01605SRiver Riddle // | | |
80ace01605SRiver Riddle // | ... |
81ace01605SRiver Riddle // | | |
82ace01605SRiver Riddle // | +--------------------------------+ |
83ace01605SRiver Riddle // | | body-last: | |
84ace01605SRiver Riddle // | | <body contents> | |
85ace01605SRiver Riddle // | | <operands of yield = %yields>| |
86ace01605SRiver Riddle // | | %new_iv =<add step to %iv> | |
87ace01605SRiver Riddle // | | cf.br cond(%new_iv, %yields) | |
88ace01605SRiver Riddle // | +--------------------------------+ |
89ace01605SRiver Riddle // | | |
90ace01605SRiver Riddle // |----------- |--------------------
91ace01605SRiver Riddle // v
92ace01605SRiver Riddle // +--------------------------------+
93ace01605SRiver Riddle // | end: |
94ace01605SRiver Riddle // | <code after the ForOp> |
95ace01605SRiver Riddle // | <%init visible by dominance> |
96ace01605SRiver Riddle // +--------------------------------+
97ace01605SRiver Riddle //
98ace01605SRiver Riddle struct ForLowering : public OpRewritePattern<ForOp> {
99ace01605SRiver Riddle using OpRewritePattern<ForOp>::OpRewritePattern;
100ace01605SRiver Riddle
101ace01605SRiver Riddle LogicalResult matchAndRewrite(ForOp forOp,
102ace01605SRiver Riddle PatternRewriter &rewriter) const override;
103ace01605SRiver Riddle };
104ace01605SRiver Riddle
105ace01605SRiver Riddle // Create a CFG subgraph for the scf.if operation (including its "then" and
106ace01605SRiver Riddle // optional "else" operation blocks). We maintain the invariants that the
107ace01605SRiver Riddle // subgraph has a single entry and a single exit point, and that the entry/exit
108ace01605SRiver Riddle // blocks are respectively the first/last block of the enclosing region. The
109ace01605SRiver Riddle // operations following the scf.if are split into a continuation (subgraph
110ace01605SRiver Riddle // exit) block. The condition is lowered to a chain of blocks that implement the
111ace01605SRiver Riddle // short-circuit scheme. The "scf.if" operation is replaced with a conditional
112ace01605SRiver Riddle // branch to either the first block of the "then" region, or to the first block
// of the "else" region. In these blocks, each "scf.yield" is replaced by an
// unconditional branch to the post-dominating block. When the "scf.if" does
// not return values, the
115ace01605SRiver Riddle // post-dominating block is the same as the continuation block. When it returns
116ace01605SRiver Riddle // values, the post-dominating block is a new block with arguments that
117ace01605SRiver Riddle // correspond to the values returned by the "scf.if" that unconditionally
118ace01605SRiver Riddle // branches to the continuation block. This allows block arguments to dominate
119ace01605SRiver Riddle // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
120ace01605SRiver Riddle // new block allows us to avoid modifying the argument list of an existing
121ace01605SRiver Riddle // block, which is illegal in a conversion pattern). When the "else" region is
122ace01605SRiver Riddle // empty, which is only allowed for "scf.if"s that don't return values, the
123ace01605SRiver Riddle // condition branches directly to the continuation block.
124ace01605SRiver Riddle //
125ace01605SRiver Riddle // CFG for a scf.if with else and without results.
126ace01605SRiver Riddle //
127ace01605SRiver Riddle // +--------------------------------+
128ace01605SRiver Riddle // | <code before the IfOp> |
129ace01605SRiver Riddle // | cf.cond_br %cond, %then, %else |
130ace01605SRiver Riddle // +--------------------------------+
131ace01605SRiver Riddle // | |
132ace01605SRiver Riddle // | --------------|
133ace01605SRiver Riddle // v |
134ace01605SRiver Riddle // +--------------------------------+ |
135ace01605SRiver Riddle // | then: | |
136ace01605SRiver Riddle // | <then contents> | |
137ace01605SRiver Riddle // | cf.br continue | |
138ace01605SRiver Riddle // +--------------------------------+ |
139ace01605SRiver Riddle // | |
140ace01605SRiver Riddle // |---------- |-------------
141ace01605SRiver Riddle // | V
142ace01605SRiver Riddle // | +--------------------------------+
143ace01605SRiver Riddle // | | else: |
144ace01605SRiver Riddle // | | <else contents> |
145ace01605SRiver Riddle // | | cf.br continue |
146ace01605SRiver Riddle // | +--------------------------------+
147ace01605SRiver Riddle // | |
148ace01605SRiver Riddle // ------| |
149ace01605SRiver Riddle // v v
150ace01605SRiver Riddle // +--------------------------------+
151ace01605SRiver Riddle // | continue: |
152ace01605SRiver Riddle // | <code after the IfOp> |
153ace01605SRiver Riddle // +--------------------------------+
154ace01605SRiver Riddle //
155ace01605SRiver Riddle // CFG for a scf.if with results.
156ace01605SRiver Riddle //
157ace01605SRiver Riddle // +--------------------------------+
158ace01605SRiver Riddle // | <code before the IfOp> |
159ace01605SRiver Riddle // | cf.cond_br %cond, %then, %else |
160ace01605SRiver Riddle // +--------------------------------+
161ace01605SRiver Riddle // | |
162ace01605SRiver Riddle // | --------------|
163ace01605SRiver Riddle // v |
164ace01605SRiver Riddle // +--------------------------------+ |
165ace01605SRiver Riddle // | then: | |
166ace01605SRiver Riddle // | <then contents> | |
167ace01605SRiver Riddle // | cf.br dom(%args...) | |
168ace01605SRiver Riddle // +--------------------------------+ |
169ace01605SRiver Riddle // | |
170ace01605SRiver Riddle // |---------- |-------------
171ace01605SRiver Riddle // | V
172ace01605SRiver Riddle // | +--------------------------------+
173ace01605SRiver Riddle // | | else: |
174ace01605SRiver Riddle // | | <else contents> |
175ace01605SRiver Riddle // | | cf.br dom(%args...) |
176ace01605SRiver Riddle // | +--------------------------------+
177ace01605SRiver Riddle // | |
178ace01605SRiver Riddle // ------| |
179ace01605SRiver Riddle // v v
180ace01605SRiver Riddle // +--------------------------------+
181ace01605SRiver Riddle // | dom(%args...): |
182ace01605SRiver Riddle // | cf.br continue |
183ace01605SRiver Riddle // +--------------------------------+
184ace01605SRiver Riddle // |
185ace01605SRiver Riddle // v
186ace01605SRiver Riddle // +--------------------------------+
187ace01605SRiver Riddle // | continue: |
188ace01605SRiver Riddle // | <code after the IfOp> |
189ace01605SRiver Riddle // +--------------------------------+
190ace01605SRiver Riddle //
191ace01605SRiver Riddle struct IfLowering : public OpRewritePattern<IfOp> {
192ace01605SRiver Riddle using OpRewritePattern<IfOp>::OpRewritePattern;
193ace01605SRiver Riddle
194ace01605SRiver Riddle LogicalResult matchAndRewrite(IfOp ifOp,
195ace01605SRiver Riddle PatternRewriter &rewriter) const override;
196ace01605SRiver Riddle };
197ace01605SRiver Riddle
198ace01605SRiver Riddle struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
199ace01605SRiver Riddle using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
200ace01605SRiver Riddle
201ace01605SRiver Riddle LogicalResult matchAndRewrite(ExecuteRegionOp op,
202ace01605SRiver Riddle PatternRewriter &rewriter) const override;
203ace01605SRiver Riddle };
204ace01605SRiver Riddle
205ace01605SRiver Riddle struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
206ace01605SRiver Riddle using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
207ace01605SRiver Riddle
208ace01605SRiver Riddle LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
209ace01605SRiver Riddle PatternRewriter &rewriter) const override;
210ace01605SRiver Riddle };
211ace01605SRiver Riddle
212ace01605SRiver Riddle /// Create a CFG subgraph for this loop construct. The regions of the loop need
213ace01605SRiver Riddle /// not be a single block anymore (for example, if other SCF constructs that
214ace01605SRiver Riddle /// they contain have been already converted to CFG), but need to be single-exit
215ace01605SRiver Riddle /// from the last block of each region. The operations following the original
216ace01605SRiver Riddle /// WhileOp are split into a new continuation block. Both regions of the WhileOp
217ace01605SRiver Riddle /// are inlined, and their terminators are rewritten to organize the control
218ace01605SRiver Riddle /// flow implementing the loop as follows.
219ace01605SRiver Riddle ///
220ace01605SRiver Riddle /// +---------------------------------+
221ace01605SRiver Riddle /// | <code before the WhileOp> |
222ace01605SRiver Riddle /// | cf.br ^before(%operands...) |
223ace01605SRiver Riddle /// +---------------------------------+
224ace01605SRiver Riddle /// |
225ace01605SRiver Riddle /// -------| |
226ace01605SRiver Riddle /// | v v
227ace01605SRiver Riddle /// | +--------------------------------+
228ace01605SRiver Riddle /// | | ^before(%bargs...): |
229ace01605SRiver Riddle /// | | %vals... = <some payload> |
230ace01605SRiver Riddle /// | +--------------------------------+
231ace01605SRiver Riddle /// | |
232ace01605SRiver Riddle /// | ...
233ace01605SRiver Riddle /// | |
234ace01605SRiver Riddle /// | +--------------------------------+
235ace01605SRiver Riddle /// | | ^before-last:
236ace01605SRiver Riddle /// | | %cond = <compute condition> |
237ace01605SRiver Riddle /// | | cf.cond_br %cond, |
238ace01605SRiver Riddle /// | | ^after(%vals...), ^cont |
239ace01605SRiver Riddle /// | +--------------------------------+
240ace01605SRiver Riddle /// | | |
241ace01605SRiver Riddle /// | | -------------|
242ace01605SRiver Riddle /// | v |
243ace01605SRiver Riddle /// | +--------------------------------+ |
244ace01605SRiver Riddle /// | | ^after(%aargs...): | |
245ace01605SRiver Riddle /// | | <body contents> | |
246ace01605SRiver Riddle /// | +--------------------------------+ |
247ace01605SRiver Riddle /// | | |
248ace01605SRiver Riddle /// | ... |
249ace01605SRiver Riddle /// | | |
250ace01605SRiver Riddle /// | +--------------------------------+ |
251ace01605SRiver Riddle /// | | ^after-last: | |
252ace01605SRiver Riddle /// | | %yields... = <some payload> | |
253ace01605SRiver Riddle /// | | cf.br ^before(%yields...) | |
254ace01605SRiver Riddle /// | +--------------------------------+ |
255ace01605SRiver Riddle /// | | |
256ace01605SRiver Riddle /// |----------- |--------------------
257ace01605SRiver Riddle /// v
258ace01605SRiver Riddle /// +--------------------------------+
259ace01605SRiver Riddle /// | ^cont: |
260ace01605SRiver Riddle /// | <code after the WhileOp> |
261ace01605SRiver Riddle /// | <%vals from 'before' region |
262ace01605SRiver Riddle /// | visible by dominance> |
263ace01605SRiver Riddle /// +--------------------------------+
264ace01605SRiver Riddle ///
265ace01605SRiver Riddle /// Values are communicated between ex-regions (the groups of blocks that used
266ace01605SRiver Riddle /// to form a region before inlining) through block arguments of their
267ace01605SRiver Riddle /// entry blocks, which are visible in all other dominated blocks. Similarly,
268ace01605SRiver Riddle /// the results of the WhileOp are defined in the 'before' region, which is
269ace01605SRiver Riddle /// required to have a single existing block, and are therefore accessible in
270ace01605SRiver Riddle /// the continuation block due to dominance.
271ace01605SRiver Riddle struct WhileLowering : public OpRewritePattern<WhileOp> {
272ace01605SRiver Riddle using OpRewritePattern<WhileOp>::OpRewritePattern;
273ace01605SRiver Riddle
274ace01605SRiver Riddle LogicalResult matchAndRewrite(WhileOp whileOp,
275ace01605SRiver Riddle PatternRewriter &rewriter) const override;
276ace01605SRiver Riddle };
277ace01605SRiver Riddle
278ace01605SRiver Riddle /// Optimized version of the above for the case of the "after" region merely
279ace01605SRiver Riddle /// forwarding its arguments back to the "before" region (i.e., a "do-while"
280ace01605SRiver Riddle /// loop). This avoid inlining the "after" region completely and branches back
281ace01605SRiver Riddle /// to the "before" entry instead.
282ace01605SRiver Riddle struct DoWhileLowering : public OpRewritePattern<WhileOp> {
283ace01605SRiver Riddle using OpRewritePattern<WhileOp>::OpRewritePattern;
284ace01605SRiver Riddle
285ace01605SRiver Riddle LogicalResult matchAndRewrite(WhileOp whileOp,
286ace01605SRiver Riddle PatternRewriter &rewriter) const override;
287ace01605SRiver Riddle };
288ace01605SRiver Riddle } // namespace
289ace01605SRiver Riddle
matchAndRewrite(ForOp forOp,PatternRewriter & rewriter) const290ace01605SRiver Riddle LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
291ace01605SRiver Riddle PatternRewriter &rewriter) const {
292ace01605SRiver Riddle Location loc = forOp.getLoc();
293ace01605SRiver Riddle
294ace01605SRiver Riddle // Start by splitting the block containing the 'scf.for' into two parts.
295ace01605SRiver Riddle // The part before will get the init code, the part after will be the end
296ace01605SRiver Riddle // point.
297ace01605SRiver Riddle auto *initBlock = rewriter.getInsertionBlock();
298ace01605SRiver Riddle auto initPosition = rewriter.getInsertionPoint();
299ace01605SRiver Riddle auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
300ace01605SRiver Riddle
301ace01605SRiver Riddle // Use the first block of the loop body as the condition block since it is the
302ace01605SRiver Riddle // block that has the induction variable and loop-carried values as arguments.
303ace01605SRiver Riddle // Split out all operations from the first block into a new block. Move all
304ace01605SRiver Riddle // body blocks from the loop body region to the region containing the loop.
305ace01605SRiver Riddle auto *conditionBlock = &forOp.getRegion().front();
306ace01605SRiver Riddle auto *firstBodyBlock =
307ace01605SRiver Riddle rewriter.splitBlock(conditionBlock, conditionBlock->begin());
308ace01605SRiver Riddle auto *lastBodyBlock = &forOp.getRegion().back();
309ace01605SRiver Riddle rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
310ace01605SRiver Riddle auto iv = conditionBlock->getArgument(0);
311ace01605SRiver Riddle
312ace01605SRiver Riddle // Append the induction variable stepping logic to the last body block and
313ace01605SRiver Riddle // branch back to the condition block. Loop-carried values are taken from
314ace01605SRiver Riddle // operands of the loop terminator.
315ace01605SRiver Riddle Operation *terminator = lastBodyBlock->getTerminator();
316ace01605SRiver Riddle rewriter.setInsertionPointToEnd(lastBodyBlock);
317ace01605SRiver Riddle auto step = forOp.getStep();
318ace01605SRiver Riddle auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
319ace01605SRiver Riddle if (!stepped)
320ace01605SRiver Riddle return failure();
321ace01605SRiver Riddle
322ace01605SRiver Riddle SmallVector<Value, 8> loopCarried;
323ace01605SRiver Riddle loopCarried.push_back(stepped);
324ace01605SRiver Riddle loopCarried.append(terminator->operand_begin(), terminator->operand_end());
325ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
326ace01605SRiver Riddle rewriter.eraseOp(terminator);
327ace01605SRiver Riddle
328ace01605SRiver Riddle // Compute loop bounds before branching to the condition.
329ace01605SRiver Riddle rewriter.setInsertionPointToEnd(initBlock);
330ace01605SRiver Riddle Value lowerBound = forOp.getLowerBound();
331ace01605SRiver Riddle Value upperBound = forOp.getUpperBound();
332ace01605SRiver Riddle if (!lowerBound || !upperBound)
333ace01605SRiver Riddle return failure();
334ace01605SRiver Riddle
335ace01605SRiver Riddle // The initial values of loop-carried values is obtained from the operands
336ace01605SRiver Riddle // of the loop operation.
337ace01605SRiver Riddle SmallVector<Value, 8> destOperands;
338ace01605SRiver Riddle destOperands.push_back(lowerBound);
339ace01605SRiver Riddle auto iterOperands = forOp.getIterOperands();
340ace01605SRiver Riddle destOperands.append(iterOperands.begin(), iterOperands.end());
341ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
342ace01605SRiver Riddle
343ace01605SRiver Riddle // With the body block done, we can fill in the condition block.
344ace01605SRiver Riddle rewriter.setInsertionPointToEnd(conditionBlock);
345ace01605SRiver Riddle auto comparison = rewriter.create<arith::CmpIOp>(
346ace01605SRiver Riddle loc, arith::CmpIPredicate::slt, iv, upperBound);
347ace01605SRiver Riddle
348ace01605SRiver Riddle rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
349ace01605SRiver Riddle ArrayRef<Value>(), endBlock,
350ace01605SRiver Riddle ArrayRef<Value>());
351ace01605SRiver Riddle // The result of the loop operation is the values of the condition block
352ace01605SRiver Riddle // arguments except the induction variable on the last iteration.
353ace01605SRiver Riddle rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
354ace01605SRiver Riddle return success();
355ace01605SRiver Riddle }
356ace01605SRiver Riddle
matchAndRewrite(IfOp ifOp,PatternRewriter & rewriter) const357ace01605SRiver Riddle LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
358ace01605SRiver Riddle PatternRewriter &rewriter) const {
359ace01605SRiver Riddle auto loc = ifOp.getLoc();
360ace01605SRiver Riddle
361ace01605SRiver Riddle // Start by splitting the block containing the 'scf.if' into two parts.
362ace01605SRiver Riddle // The part before will contain the condition, the part after will be the
363ace01605SRiver Riddle // continuation point.
364ace01605SRiver Riddle auto *condBlock = rewriter.getInsertionBlock();
365ace01605SRiver Riddle auto opPosition = rewriter.getInsertionPoint();
366ace01605SRiver Riddle auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
367ace01605SRiver Riddle Block *continueBlock;
368ace01605SRiver Riddle if (ifOp.getNumResults() == 0) {
369ace01605SRiver Riddle continueBlock = remainingOpsBlock;
370ace01605SRiver Riddle } else {
371ace01605SRiver Riddle continueBlock =
372ace01605SRiver Riddle rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
373ace01605SRiver Riddle SmallVector<Location>(ifOp.getNumResults(), loc));
374ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
375ace01605SRiver Riddle }
376ace01605SRiver Riddle
377ace01605SRiver Riddle // Move blocks from the "then" region to the region containing 'scf.if',
378ace01605SRiver Riddle // place it before the continuation block, and branch to it.
379ace01605SRiver Riddle auto &thenRegion = ifOp.getThenRegion();
380ace01605SRiver Riddle auto *thenBlock = &thenRegion.front();
381ace01605SRiver Riddle Operation *thenTerminator = thenRegion.back().getTerminator();
382ace01605SRiver Riddle ValueRange thenTerminatorOperands = thenTerminator->getOperands();
383ace01605SRiver Riddle rewriter.setInsertionPointToEnd(&thenRegion.back());
384ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
385ace01605SRiver Riddle rewriter.eraseOp(thenTerminator);
386ace01605SRiver Riddle rewriter.inlineRegionBefore(thenRegion, continueBlock);
387ace01605SRiver Riddle
388ace01605SRiver Riddle // Move blocks from the "else" region (if present) to the region containing
389ace01605SRiver Riddle // 'scf.if', place it before the continuation block and branch to it. It
390ace01605SRiver Riddle // will be placed after the "then" regions.
391ace01605SRiver Riddle auto *elseBlock = continueBlock;
392ace01605SRiver Riddle auto &elseRegion = ifOp.getElseRegion();
393ace01605SRiver Riddle if (!elseRegion.empty()) {
394ace01605SRiver Riddle elseBlock = &elseRegion.front();
395ace01605SRiver Riddle Operation *elseTerminator = elseRegion.back().getTerminator();
396ace01605SRiver Riddle ValueRange elseTerminatorOperands = elseTerminator->getOperands();
397ace01605SRiver Riddle rewriter.setInsertionPointToEnd(&elseRegion.back());
398ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
399ace01605SRiver Riddle rewriter.eraseOp(elseTerminator);
400ace01605SRiver Riddle rewriter.inlineRegionBefore(elseRegion, continueBlock);
401ace01605SRiver Riddle }
402ace01605SRiver Riddle
403ace01605SRiver Riddle rewriter.setInsertionPointToEnd(condBlock);
404ace01605SRiver Riddle rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
405ace01605SRiver Riddle /*trueArgs=*/ArrayRef<Value>(), elseBlock,
406ace01605SRiver Riddle /*falseArgs=*/ArrayRef<Value>());
407ace01605SRiver Riddle
408ace01605SRiver Riddle // Ok, we're done!
409ace01605SRiver Riddle rewriter.replaceOp(ifOp, continueBlock->getArguments());
410ace01605SRiver Riddle return success();
411ace01605SRiver Riddle }
412ace01605SRiver Riddle
413ace01605SRiver Riddle LogicalResult
matchAndRewrite(ExecuteRegionOp op,PatternRewriter & rewriter) const414ace01605SRiver Riddle ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
415ace01605SRiver Riddle PatternRewriter &rewriter) const {
416ace01605SRiver Riddle auto loc = op.getLoc();
417ace01605SRiver Riddle
418ace01605SRiver Riddle auto *condBlock = rewriter.getInsertionBlock();
419ace01605SRiver Riddle auto opPosition = rewriter.getInsertionPoint();
420ace01605SRiver Riddle auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
421ace01605SRiver Riddle
422ace01605SRiver Riddle auto ®ion = op.getRegion();
423ace01605SRiver Riddle rewriter.setInsertionPointToEnd(condBlock);
424ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, ®ion.front());
425ace01605SRiver Riddle
426ace01605SRiver Riddle for (Block &block : region) {
427ace01605SRiver Riddle if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
428ace01605SRiver Riddle ValueRange terminatorOperands = terminator->getOperands();
429ace01605SRiver Riddle rewriter.setInsertionPointToEnd(&block);
430ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
431ace01605SRiver Riddle rewriter.eraseOp(terminator);
432ace01605SRiver Riddle }
433ace01605SRiver Riddle }
434ace01605SRiver Riddle
435ace01605SRiver Riddle rewriter.inlineRegionBefore(region, remainingOpsBlock);
436ace01605SRiver Riddle
437ace01605SRiver Riddle SmallVector<Value> vals;
438ace01605SRiver Riddle SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
439ace01605SRiver Riddle for (auto arg :
440ace01605SRiver Riddle remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
441ace01605SRiver Riddle vals.push_back(arg);
442ace01605SRiver Riddle rewriter.replaceOp(op, vals);
443ace01605SRiver Riddle return success();
444ace01605SRiver Riddle }
445ace01605SRiver Riddle
446ace01605SRiver Riddle LogicalResult
matchAndRewrite(ParallelOp parallelOp,PatternRewriter & rewriter) const447ace01605SRiver Riddle ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
448ace01605SRiver Riddle PatternRewriter &rewriter) const {
449ace01605SRiver Riddle Location loc = parallelOp.getLoc();
450ace01605SRiver Riddle
451ace01605SRiver Riddle // For a parallel loop, we essentially need to create an n-dimensional loop
452ace01605SRiver Riddle // nest. We do this by translating to scf.for ops and have those lowered in
453ace01605SRiver Riddle // a further rewrite. If a parallel loop contains reductions (and thus returns
454ace01605SRiver Riddle // values), forward the initial values for the reductions down the loop
455ace01605SRiver Riddle // hierarchy and bubble up the results by modifying the "yield" terminator.
456ace01605SRiver Riddle SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
457ace01605SRiver Riddle SmallVector<Value, 4> ivs;
458ace01605SRiver Riddle ivs.reserve(parallelOp.getNumLoops());
459ace01605SRiver Riddle bool first = true;
460ace01605SRiver Riddle SmallVector<Value, 4> loopResults(iterArgs);
461ace01605SRiver Riddle for (auto loopOperands :
462ace01605SRiver Riddle llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
463ace01605SRiver Riddle parallelOp.getUpperBound(), parallelOp.getStep())) {
464ace01605SRiver Riddle Value iv, lower, upper, step;
465ace01605SRiver Riddle std::tie(iv, lower, upper, step) = loopOperands;
466ace01605SRiver Riddle ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
467ace01605SRiver Riddle ivs.push_back(forOp.getInductionVar());
468ace01605SRiver Riddle auto iterRange = forOp.getRegionIterArgs();
469ace01605SRiver Riddle iterArgs.assign(iterRange.begin(), iterRange.end());
470ace01605SRiver Riddle
471ace01605SRiver Riddle if (first) {
472ace01605SRiver Riddle // Store the results of the outermost loop that will be used to replace
473ace01605SRiver Riddle // the results of the parallel loop when it is fully rewritten.
474ace01605SRiver Riddle loopResults.assign(forOp.result_begin(), forOp.result_end());
475ace01605SRiver Riddle first = false;
476ace01605SRiver Riddle } else if (!forOp.getResults().empty()) {
477ace01605SRiver Riddle // A loop is constructed with an empty "yield" terminator if there are
478ace01605SRiver Riddle // no results.
479ace01605SRiver Riddle rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
480ace01605SRiver Riddle rewriter.create<scf::YieldOp>(loc, forOp.getResults());
481ace01605SRiver Riddle }
482ace01605SRiver Riddle
483ace01605SRiver Riddle rewriter.setInsertionPointToStart(forOp.getBody());
484ace01605SRiver Riddle }
485ace01605SRiver Riddle
486ace01605SRiver Riddle // First, merge reduction blocks into the main region.
487ace01605SRiver Riddle SmallVector<Value, 4> yieldOperands;
488ace01605SRiver Riddle yieldOperands.reserve(parallelOp.getNumResults());
489ace01605SRiver Riddle for (auto &op : *parallelOp.getBody()) {
490ace01605SRiver Riddle auto reduce = dyn_cast<ReduceOp>(op);
491ace01605SRiver Riddle if (!reduce)
492ace01605SRiver Riddle continue;
493ace01605SRiver Riddle
494ace01605SRiver Riddle Block &reduceBlock = reduce.getReductionOperator().front();
495ace01605SRiver Riddle Value arg = iterArgs[yieldOperands.size()];
496ace01605SRiver Riddle yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
497ace01605SRiver Riddle rewriter.eraseOp(reduceBlock.getTerminator());
498ace01605SRiver Riddle rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
499ace01605SRiver Riddle rewriter.eraseOp(reduce);
500ace01605SRiver Riddle }
501ace01605SRiver Riddle
502ace01605SRiver Riddle // Then merge the loop body without the terminator.
503ace01605SRiver Riddle rewriter.eraseOp(parallelOp.getBody()->getTerminator());
504ace01605SRiver Riddle Block *newBody = rewriter.getInsertionBlock();
505ace01605SRiver Riddle if (newBody->empty())
506ace01605SRiver Riddle rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
507ace01605SRiver Riddle else
508ace01605SRiver Riddle rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
509ace01605SRiver Riddle ivs);
510ace01605SRiver Riddle
511ace01605SRiver Riddle // Finally, create the terminator if required (for loops with no results, it
512ace01605SRiver Riddle // has been already created in loop construction).
513ace01605SRiver Riddle if (!yieldOperands.empty()) {
514ace01605SRiver Riddle rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
515ace01605SRiver Riddle rewriter.create<scf::YieldOp>(loc, yieldOperands);
516ace01605SRiver Riddle }
517ace01605SRiver Riddle
518ace01605SRiver Riddle rewriter.replaceOp(parallelOp, loopResults);
519ace01605SRiver Riddle
520ace01605SRiver Riddle return success();
521ace01605SRiver Riddle }
522ace01605SRiver Riddle
matchAndRewrite(WhileOp whileOp,PatternRewriter & rewriter) const523ace01605SRiver Riddle LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
524ace01605SRiver Riddle PatternRewriter &rewriter) const {
525ace01605SRiver Riddle OpBuilder::InsertionGuard guard(rewriter);
526ace01605SRiver Riddle Location loc = whileOp.getLoc();
527ace01605SRiver Riddle
528ace01605SRiver Riddle // Split the current block before the WhileOp to create the inlining point.
529ace01605SRiver Riddle Block *currentBlock = rewriter.getInsertionBlock();
530ace01605SRiver Riddle Block *continuation =
531ace01605SRiver Riddle rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
532ace01605SRiver Riddle
533ace01605SRiver Riddle // Inline both regions.
534ace01605SRiver Riddle Block *after = &whileOp.getAfter().front();
535ace01605SRiver Riddle Block *afterLast = &whileOp.getAfter().back();
536ace01605SRiver Riddle Block *before = &whileOp.getBefore().front();
537ace01605SRiver Riddle Block *beforeLast = &whileOp.getBefore().back();
538ace01605SRiver Riddle rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
539ace01605SRiver Riddle rewriter.inlineRegionBefore(whileOp.getBefore(), after);
540ace01605SRiver Riddle
541ace01605SRiver Riddle // Branch to the "before" region.
542ace01605SRiver Riddle rewriter.setInsertionPointToEnd(currentBlock);
543ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
544ace01605SRiver Riddle
545ace01605SRiver Riddle // Replace terminators with branches. Assuming bodies are SESE, which holds
546ace01605SRiver Riddle // given only the patterns from this file, we only need to look at the last
547ace01605SRiver Riddle // block. This should be reconsidered if we allow break/continue in SCF.
548ace01605SRiver Riddle rewriter.setInsertionPointToEnd(beforeLast);
549ace01605SRiver Riddle auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
550ace01605SRiver Riddle rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
551ace01605SRiver Riddle after, condOp.getArgs(),
552ace01605SRiver Riddle continuation, ValueRange());
553ace01605SRiver Riddle
554ace01605SRiver Riddle rewriter.setInsertionPointToEnd(afterLast);
555ace01605SRiver Riddle auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
556ace01605SRiver Riddle rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
557ace01605SRiver Riddle yieldOp.getResults());
558ace01605SRiver Riddle
559ace01605SRiver Riddle // Replace the op with values "yielded" from the "before" region, which are
560ace01605SRiver Riddle // visible by dominance.
561ace01605SRiver Riddle rewriter.replaceOp(whileOp, condOp.getArgs());
562ace01605SRiver Riddle
563ace01605SRiver Riddle return success();
564ace01605SRiver Riddle }
565ace01605SRiver Riddle
566ace01605SRiver Riddle LogicalResult
matchAndRewrite(WhileOp whileOp,PatternRewriter & rewriter) const567ace01605SRiver Riddle DoWhileLowering::matchAndRewrite(WhileOp whileOp,
568ace01605SRiver Riddle PatternRewriter &rewriter) const {
569ace01605SRiver Riddle if (!llvm::hasSingleElement(whileOp.getAfter()))
570ace01605SRiver Riddle return rewriter.notifyMatchFailure(whileOp,
571ace01605SRiver Riddle "do-while simplification applicable to "
572ace01605SRiver Riddle "single-block 'after' region only");
573ace01605SRiver Riddle
574ace01605SRiver Riddle Block &afterBlock = whileOp.getAfter().front();
575ace01605SRiver Riddle if (!llvm::hasSingleElement(afterBlock))
576ace01605SRiver Riddle return rewriter.notifyMatchFailure(whileOp,
577ace01605SRiver Riddle "do-while simplification applicable "
578ace01605SRiver Riddle "only if 'after' region has no payload");
579ace01605SRiver Riddle
580ace01605SRiver Riddle auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
581ace01605SRiver Riddle if (!yield || yield.getResults() != afterBlock.getArguments())
582ace01605SRiver Riddle return rewriter.notifyMatchFailure(whileOp,
583ace01605SRiver Riddle "do-while simplification applicable "
584ace01605SRiver Riddle "only to forwarding 'after' regions");
585ace01605SRiver Riddle
586ace01605SRiver Riddle // Split the current block before the WhileOp to create the inlining point.
587ace01605SRiver Riddle OpBuilder::InsertionGuard guard(rewriter);
588ace01605SRiver Riddle Block *currentBlock = rewriter.getInsertionBlock();
589ace01605SRiver Riddle Block *continuation =
590ace01605SRiver Riddle rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
591ace01605SRiver Riddle
592ace01605SRiver Riddle // Only the "before" region should be inlined.
593ace01605SRiver Riddle Block *before = &whileOp.getBefore().front();
594ace01605SRiver Riddle Block *beforeLast = &whileOp.getBefore().back();
595ace01605SRiver Riddle rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
596ace01605SRiver Riddle
597ace01605SRiver Riddle // Branch to the "before" region.
598ace01605SRiver Riddle rewriter.setInsertionPointToEnd(currentBlock);
599ace01605SRiver Riddle rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
600ace01605SRiver Riddle
601ace01605SRiver Riddle // Loop around the "before" region based on condition.
602ace01605SRiver Riddle rewriter.setInsertionPointToEnd(beforeLast);
603ace01605SRiver Riddle auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
604ace01605SRiver Riddle rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
605ace01605SRiver Riddle before, condOp.getArgs(),
606ace01605SRiver Riddle continuation, ValueRange());
607ace01605SRiver Riddle
608ace01605SRiver Riddle // Replace the op with values "yielded" from the "before" region, which are
609ace01605SRiver Riddle // visible by dominance.
610ace01605SRiver Riddle rewriter.replaceOp(whileOp, condOp.getArgs());
611ace01605SRiver Riddle
612ace01605SRiver Riddle return success();
613ace01605SRiver Riddle }
614ace01605SRiver Riddle
populateSCFToControlFlowConversionPatterns(RewritePatternSet & patterns)615ace01605SRiver Riddle void mlir::populateSCFToControlFlowConversionPatterns(
616ace01605SRiver Riddle RewritePatternSet &patterns) {
617ace01605SRiver Riddle patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
618ace01605SRiver Riddle ExecuteRegionLowering>(patterns.getContext());
619ace01605SRiver Riddle patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
620ace01605SRiver Riddle }
621ace01605SRiver Riddle
runOnOperation()622ace01605SRiver Riddle void SCFToControlFlowPass::runOnOperation() {
623ace01605SRiver Riddle RewritePatternSet patterns(&getContext());
624ace01605SRiver Riddle populateSCFToControlFlowConversionPatterns(patterns);
625ace01605SRiver Riddle
626ace01605SRiver Riddle // Configure conversion to lower out SCF operations.
627ace01605SRiver Riddle ConversionTarget target(getContext());
628ace01605SRiver Riddle target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
629ace01605SRiver Riddle scf::ExecuteRegionOp>();
630ace01605SRiver Riddle target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
631ace01605SRiver Riddle if (failed(
632ace01605SRiver Riddle applyPartialConversion(getOperation(), target, std::move(patterns))))
633ace01605SRiver Riddle signalPassFailure();
634ace01605SRiver Riddle }
635ace01605SRiver Riddle
createConvertSCFToCFPass()636ace01605SRiver Riddle std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() {
637ace01605SRiver Riddle return std::make_unique<SCFToControlFlowPass>();
638ace01605SRiver Riddle }
639