1*ace01605SRiver Riddle //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===// 2*ace01605SRiver Riddle // 3*ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5*ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*ace01605SRiver Riddle // 7*ace01605SRiver Riddle //===----------------------------------------------------------------------===// 8*ace01605SRiver Riddle // 9*ace01605SRiver Riddle // This file implements a pass to convert scf.for, scf.if and loop.terminator 10*ace01605SRiver Riddle // ops into standard CFG ops. 11*ace01605SRiver Riddle // 12*ace01605SRiver Riddle //===----------------------------------------------------------------------===// 13*ace01605SRiver Riddle 14*ace01605SRiver Riddle #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 15*ace01605SRiver Riddle #include "../PassDetail.h" 16*ace01605SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17*ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 18*ace01605SRiver Riddle #include "mlir/Dialect/SCF/SCF.h" 19*ace01605SRiver Riddle #include "mlir/IR/BlockAndValueMapping.h" 20*ace01605SRiver Riddle #include "mlir/IR/Builders.h" 21*ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h" 22*ace01605SRiver Riddle #include "mlir/IR/MLIRContext.h" 23*ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h" 24*ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h" 25*ace01605SRiver Riddle #include "mlir/Transforms/Passes.h" 26*ace01605SRiver Riddle 27*ace01605SRiver Riddle using namespace mlir; 28*ace01605SRiver Riddle using namespace mlir::scf; 29*ace01605SRiver Riddle 30*ace01605SRiver Riddle namespace { 31*ace01605SRiver Riddle 32*ace01605SRiver Riddle struct SCFToControlFlowPass 33*ace01605SRiver Riddle : public SCFToControlFlowBase<SCFToControlFlowPass> { 34*ace01605SRiver Riddle void runOnOperation() override; 35*ace01605SRiver Riddle }; 36*ace01605SRiver Riddle 37*ace01605SRiver Riddle // Create a CFG subgraph for the loop around its body blocks (if the body 38*ace01605SRiver Riddle // contained other loops, they have been already lowered to a flow of blocks). 39*ace01605SRiver Riddle // Maintain the invariants that a CFG subgraph created for any loop has a single 40*ace01605SRiver Riddle // entry and a single exit, and that the entry/exit blocks are respectively 41*ace01605SRiver Riddle // first/last blocks in the parent region. The original loop operation is 42*ace01605SRiver Riddle // replaced by the initialization operations that set up the initial value of 43*ace01605SRiver Riddle // the loop induction variable (%iv) and computes the loop bounds that are loop- 44*ace01605SRiver Riddle // invariant for affine loops. The operations following the original scf.for 45*ace01605SRiver Riddle // are split out into a separate continuation (exit) block. A condition block is 46*ace01605SRiver Riddle // created before the continuation block. It checks the exit condition of the 47*ace01605SRiver Riddle // loop and branches either to the continuation block, or to the first block of 48*ace01605SRiver Riddle // the body. The condition block takes as arguments the values of the induction 49*ace01605SRiver Riddle // variable followed by loop-carried values. Since it dominates both the body 50*ace01605SRiver Riddle // blocks and the continuation block, loop-carried values are visible in all of 51*ace01605SRiver Riddle // those blocks. Induction variable modification is appended to the last block 52*ace01605SRiver Riddle // of the body (which is the exit block from the body subgraph thanks to the 53*ace01605SRiver Riddle // invariant we maintain) along with a branch that loops back to the condition 54*ace01605SRiver Riddle // block. Loop-carried values are the loop terminator operands, which are 55*ace01605SRiver Riddle // forwarded to the branch. 56*ace01605SRiver Riddle // 57*ace01605SRiver Riddle // +---------------------------------+ 58*ace01605SRiver Riddle // | <code before the ForOp> | 59*ace01605SRiver Riddle // | <definitions of %init...> | 60*ace01605SRiver Riddle // | <compute initial %iv value> | 61*ace01605SRiver Riddle // | cf.br cond(%iv, %init...) | 62*ace01605SRiver Riddle // +---------------------------------+ 63*ace01605SRiver Riddle // | 64*ace01605SRiver Riddle // -------| | 65*ace01605SRiver Riddle // | v v 66*ace01605SRiver Riddle // | +--------------------------------+ 67*ace01605SRiver Riddle // | | cond(%iv, %init...): | 68*ace01605SRiver Riddle // | | <compare %iv to upper bound> | 69*ace01605SRiver Riddle // | | cf.cond_br %r, body, end | 70*ace01605SRiver Riddle // | +--------------------------------+ 71*ace01605SRiver Riddle // | | | 72*ace01605SRiver Riddle // | | -------------| 73*ace01605SRiver Riddle // | v | 74*ace01605SRiver Riddle // | +--------------------------------+ | 75*ace01605SRiver Riddle // | | body-first: | | 76*ace01605SRiver Riddle // | | <%init visible by dominance> | | 77*ace01605SRiver Riddle // | | <body contents> | | 78*ace01605SRiver Riddle // | +--------------------------------+ | 79*ace01605SRiver Riddle // | | | 80*ace01605SRiver Riddle // | ... | 81*ace01605SRiver Riddle // | | | 82*ace01605SRiver Riddle // | +--------------------------------+ | 83*ace01605SRiver Riddle // | | body-last: | | 84*ace01605SRiver Riddle // | | <body contents> | | 85*ace01605SRiver Riddle // | | <operands of yield = %yields>| | 86*ace01605SRiver Riddle // | | %new_iv =<add step to %iv> | | 87*ace01605SRiver Riddle // | | cf.br cond(%new_iv, %yields) | | 88*ace01605SRiver Riddle // | +--------------------------------+ | 89*ace01605SRiver Riddle // | | | 90*ace01605SRiver Riddle // |----------- |-------------------- 91*ace01605SRiver Riddle // v 92*ace01605SRiver Riddle // +--------------------------------+ 93*ace01605SRiver Riddle // | end: | 94*ace01605SRiver Riddle // | <code after the ForOp> | 95*ace01605SRiver Riddle // | <%init visible by dominance> | 96*ace01605SRiver Riddle // +--------------------------------+ 97*ace01605SRiver Riddle // 98*ace01605SRiver Riddle struct ForLowering : public OpRewritePattern<ForOp> { 99*ace01605SRiver Riddle using OpRewritePattern<ForOp>::OpRewritePattern; 100*ace01605SRiver Riddle 101*ace01605SRiver Riddle LogicalResult matchAndRewrite(ForOp forOp, 102*ace01605SRiver Riddle PatternRewriter &rewriter) const override; 103*ace01605SRiver Riddle }; 104*ace01605SRiver Riddle 105*ace01605SRiver Riddle // Create a CFG subgraph for the scf.if operation (including its "then" and 106*ace01605SRiver Riddle // optional "else" operation blocks). We maintain the invariants that the 107*ace01605SRiver Riddle // subgraph has a single entry and a single exit point, and that the entry/exit 108*ace01605SRiver Riddle // blocks are respectively the first/last block of the enclosing region. The 109*ace01605SRiver Riddle // operations following the scf.if are split into a continuation (subgraph 110*ace01605SRiver Riddle // exit) block. The condition is lowered to a chain of blocks that implement the 111*ace01605SRiver Riddle // short-circuit scheme. The "scf.if" operation is replaced with a conditional 112*ace01605SRiver Riddle // branch to either the first block of the "then" region, or to the first block 113*ace01605SRiver Riddle // of the "else" region. In these blocks, "scf.yield" is unconditional branches 114*ace01605SRiver Riddle // to the post-dominating block. When the "scf.if" does not return values, the 115*ace01605SRiver Riddle // post-dominating block is the same as the continuation block. When it returns 116*ace01605SRiver Riddle // values, the post-dominating block is a new block with arguments that 117*ace01605SRiver Riddle // correspond to the values returned by the "scf.if" that unconditionally 118*ace01605SRiver Riddle // branches to the continuation block. This allows block arguments to dominate 119*ace01605SRiver Riddle // any uses of the hitherto "scf.if" results that they replaced. (Inserting a 120*ace01605SRiver Riddle // new block allows us to avoid modifying the argument list of an existing 121*ace01605SRiver Riddle // block, which is illegal in a conversion pattern). When the "else" region is 122*ace01605SRiver Riddle // empty, which is only allowed for "scf.if"s that don't return values, the 123*ace01605SRiver Riddle // condition branches directly to the continuation block. 124*ace01605SRiver Riddle // 125*ace01605SRiver Riddle // CFG for a scf.if with else and without results. 126*ace01605SRiver Riddle // 127*ace01605SRiver Riddle // +--------------------------------+ 128*ace01605SRiver Riddle // | <code before the IfOp> | 129*ace01605SRiver Riddle // | cf.cond_br %cond, %then, %else | 130*ace01605SRiver Riddle // +--------------------------------+ 131*ace01605SRiver Riddle // | | 132*ace01605SRiver Riddle // | --------------| 133*ace01605SRiver Riddle // v | 134*ace01605SRiver Riddle // +--------------------------------+ | 135*ace01605SRiver Riddle // | then: | | 136*ace01605SRiver Riddle // | <then contents> | | 137*ace01605SRiver Riddle // | cf.br continue | | 138*ace01605SRiver Riddle // +--------------------------------+ | 139*ace01605SRiver Riddle // | | 140*ace01605SRiver Riddle // |---------- |------------- 141*ace01605SRiver Riddle // | V 142*ace01605SRiver Riddle // | +--------------------------------+ 143*ace01605SRiver Riddle // | | else: | 144*ace01605SRiver Riddle // | | <else contents> | 145*ace01605SRiver Riddle // | | cf.br continue | 146*ace01605SRiver Riddle // | +--------------------------------+ 147*ace01605SRiver Riddle // | | 148*ace01605SRiver Riddle // ------| | 149*ace01605SRiver Riddle // v v 150*ace01605SRiver Riddle // +--------------------------------+ 151*ace01605SRiver Riddle // | continue: | 152*ace01605SRiver Riddle // | <code after the IfOp> | 153*ace01605SRiver Riddle // +--------------------------------+ 154*ace01605SRiver Riddle // 155*ace01605SRiver Riddle // CFG for a scf.if with results. 156*ace01605SRiver Riddle // 157*ace01605SRiver Riddle // +--------------------------------+ 158*ace01605SRiver Riddle // | <code before the IfOp> | 159*ace01605SRiver Riddle // | cf.cond_br %cond, %then, %else | 160*ace01605SRiver Riddle // +--------------------------------+ 161*ace01605SRiver Riddle // | | 162*ace01605SRiver Riddle // | --------------| 163*ace01605SRiver Riddle // v | 164*ace01605SRiver Riddle // +--------------------------------+ | 165*ace01605SRiver Riddle // | then: | | 166*ace01605SRiver Riddle // | <then contents> | | 167*ace01605SRiver Riddle // | cf.br dom(%args...) | | 168*ace01605SRiver Riddle // +--------------------------------+ | 169*ace01605SRiver Riddle // | | 170*ace01605SRiver Riddle // |---------- |------------- 171*ace01605SRiver Riddle // | V 172*ace01605SRiver Riddle // | +--------------------------------+ 173*ace01605SRiver Riddle // | | else: | 174*ace01605SRiver Riddle // | | <else contents> | 175*ace01605SRiver Riddle // | | cf.br dom(%args...) | 176*ace01605SRiver Riddle // | +--------------------------------+ 177*ace01605SRiver Riddle // | | 178*ace01605SRiver Riddle // ------| | 179*ace01605SRiver Riddle // v v 180*ace01605SRiver Riddle // +--------------------------------+ 181*ace01605SRiver Riddle // | dom(%args...): | 182*ace01605SRiver Riddle // | cf.br continue | 183*ace01605SRiver Riddle // +--------------------------------+ 184*ace01605SRiver Riddle // | 185*ace01605SRiver Riddle // v 186*ace01605SRiver Riddle // +--------------------------------+ 187*ace01605SRiver Riddle // | continue: | 188*ace01605SRiver Riddle // | <code after the IfOp> | 189*ace01605SRiver Riddle // +--------------------------------+ 190*ace01605SRiver Riddle // 191*ace01605SRiver Riddle struct IfLowering : public OpRewritePattern<IfOp> { 192*ace01605SRiver Riddle using OpRewritePattern<IfOp>::OpRewritePattern; 193*ace01605SRiver Riddle 194*ace01605SRiver Riddle LogicalResult matchAndRewrite(IfOp ifOp, 195*ace01605SRiver Riddle PatternRewriter &rewriter) const override; 196*ace01605SRiver Riddle }; 197*ace01605SRiver Riddle 198*ace01605SRiver Riddle struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> { 199*ace01605SRiver Riddle using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; 200*ace01605SRiver Riddle 201*ace01605SRiver Riddle LogicalResult matchAndRewrite(ExecuteRegionOp op, 202*ace01605SRiver Riddle PatternRewriter &rewriter) const override; 203*ace01605SRiver Riddle }; 204*ace01605SRiver Riddle 205*ace01605SRiver Riddle struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> { 206*ace01605SRiver Riddle using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern; 207*ace01605SRiver Riddle 208*ace01605SRiver Riddle LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, 209*ace01605SRiver Riddle PatternRewriter &rewriter) const override; 210*ace01605SRiver Riddle }; 211*ace01605SRiver Riddle 212*ace01605SRiver Riddle /// Create a CFG subgraph for this loop construct. The regions of the loop need 213*ace01605SRiver Riddle /// not be a single block anymore (for example, if other SCF constructs that 214*ace01605SRiver Riddle /// they contain have been already converted to CFG), but need to be single-exit 215*ace01605SRiver Riddle /// from the last block of each region. The operations following the original 216*ace01605SRiver Riddle /// WhileOp are split into a new continuation block. Both regions of the WhileOp 217*ace01605SRiver Riddle /// are inlined, and their terminators are rewritten to organize the control 218*ace01605SRiver Riddle /// flow implementing the loop as follows. 219*ace01605SRiver Riddle /// 220*ace01605SRiver Riddle /// +---------------------------------+ 221*ace01605SRiver Riddle /// | <code before the WhileOp> | 222*ace01605SRiver Riddle /// | cf.br ^before(%operands...) | 223*ace01605SRiver Riddle /// +---------------------------------+ 224*ace01605SRiver Riddle /// | 225*ace01605SRiver Riddle /// -------| | 226*ace01605SRiver Riddle /// | v v 227*ace01605SRiver Riddle /// | +--------------------------------+ 228*ace01605SRiver Riddle /// | | ^before(%bargs...): | 229*ace01605SRiver Riddle /// | | %vals... = <some payload> | 230*ace01605SRiver Riddle /// | +--------------------------------+ 231*ace01605SRiver Riddle /// | | 232*ace01605SRiver Riddle /// | ... 233*ace01605SRiver Riddle /// | | 234*ace01605SRiver Riddle /// | +--------------------------------+ 235*ace01605SRiver Riddle /// | | ^before-last: 236*ace01605SRiver Riddle /// | | %cond = <compute condition> | 237*ace01605SRiver Riddle /// | | cf.cond_br %cond, | 238*ace01605SRiver Riddle /// | | ^after(%vals...), ^cont | 239*ace01605SRiver Riddle /// | +--------------------------------+ 240*ace01605SRiver Riddle /// | | | 241*ace01605SRiver Riddle /// | | -------------| 242*ace01605SRiver Riddle /// | v | 243*ace01605SRiver Riddle /// | +--------------------------------+ | 244*ace01605SRiver Riddle /// | | ^after(%aargs...): | | 245*ace01605SRiver Riddle /// | | <body contents> | | 246*ace01605SRiver Riddle /// | +--------------------------------+ | 247*ace01605SRiver Riddle /// | | | 248*ace01605SRiver Riddle /// | ... | 249*ace01605SRiver Riddle /// | | | 250*ace01605SRiver Riddle /// | +--------------------------------+ | 251*ace01605SRiver Riddle /// | | ^after-last: | | 252*ace01605SRiver Riddle /// | | %yields... = <some payload> | | 253*ace01605SRiver Riddle /// | | cf.br ^before(%yields...) | | 254*ace01605SRiver Riddle /// | +--------------------------------+ | 255*ace01605SRiver Riddle /// | | | 256*ace01605SRiver Riddle /// |----------- |-------------------- 257*ace01605SRiver Riddle /// v 258*ace01605SRiver Riddle /// +--------------------------------+ 259*ace01605SRiver Riddle /// | ^cont: | 260*ace01605SRiver Riddle /// | <code after the WhileOp> | 261*ace01605SRiver Riddle /// | <%vals from 'before' region | 262*ace01605SRiver Riddle /// | visible by dominance> | 263*ace01605SRiver Riddle /// +--------------------------------+ 264*ace01605SRiver Riddle /// 265*ace01605SRiver Riddle /// Values are communicated between ex-regions (the groups of blocks that used 266*ace01605SRiver Riddle /// to form a region before inlining) through block arguments of their 267*ace01605SRiver Riddle /// entry blocks, which are visible in all other dominated blocks. Similarly, 268*ace01605SRiver Riddle /// the results of the WhileOp are defined in the 'before' region, which is 269*ace01605SRiver Riddle /// required to have a single existing block, and are therefore accessible in 270*ace01605SRiver Riddle /// the continuation block due to dominance. 271*ace01605SRiver Riddle struct WhileLowering : public OpRewritePattern<WhileOp> { 272*ace01605SRiver Riddle using OpRewritePattern<WhileOp>::OpRewritePattern; 273*ace01605SRiver Riddle 274*ace01605SRiver Riddle LogicalResult matchAndRewrite(WhileOp whileOp, 275*ace01605SRiver Riddle PatternRewriter &rewriter) const override; 276*ace01605SRiver Riddle }; 277*ace01605SRiver Riddle 278*ace01605SRiver Riddle /// Optimized version of the above for the case of the "after" region merely 279*ace01605SRiver Riddle /// forwarding its arguments back to the "before" region (i.e., a "do-while" 280*ace01605SRiver Riddle /// loop). This avoid inlining the "after" region completely and branches back 281*ace01605SRiver Riddle /// to the "before" entry instead. 282*ace01605SRiver Riddle struct DoWhileLowering : public OpRewritePattern<WhileOp> { 283*ace01605SRiver Riddle using OpRewritePattern<WhileOp>::OpRewritePattern; 284*ace01605SRiver Riddle 285*ace01605SRiver Riddle LogicalResult matchAndRewrite(WhileOp whileOp, 286*ace01605SRiver Riddle PatternRewriter &rewriter) const override; 287*ace01605SRiver Riddle }; 288*ace01605SRiver Riddle } // namespace 289*ace01605SRiver Riddle 290*ace01605SRiver Riddle LogicalResult ForLowering::matchAndRewrite(ForOp forOp, 291*ace01605SRiver Riddle PatternRewriter &rewriter) const { 292*ace01605SRiver Riddle Location loc = forOp.getLoc(); 293*ace01605SRiver Riddle 294*ace01605SRiver Riddle // Start by splitting the block containing the 'scf.for' into two parts. 295*ace01605SRiver Riddle // The part before will get the init code, the part after will be the end 296*ace01605SRiver Riddle // point. 297*ace01605SRiver Riddle auto *initBlock = rewriter.getInsertionBlock(); 298*ace01605SRiver Riddle auto initPosition = rewriter.getInsertionPoint(); 299*ace01605SRiver Riddle auto *endBlock = rewriter.splitBlock(initBlock, initPosition); 300*ace01605SRiver Riddle 301*ace01605SRiver Riddle // Use the first block of the loop body as the condition block since it is the 302*ace01605SRiver Riddle // block that has the induction variable and loop-carried values as arguments. 303*ace01605SRiver Riddle // Split out all operations from the first block into a new block. Move all 304*ace01605SRiver Riddle // body blocks from the loop body region to the region containing the loop. 305*ace01605SRiver Riddle auto *conditionBlock = &forOp.getRegion().front(); 306*ace01605SRiver Riddle auto *firstBodyBlock = 307*ace01605SRiver Riddle rewriter.splitBlock(conditionBlock, conditionBlock->begin()); 308*ace01605SRiver Riddle auto *lastBodyBlock = &forOp.getRegion().back(); 309*ace01605SRiver Riddle rewriter.inlineRegionBefore(forOp.getRegion(), endBlock); 310*ace01605SRiver Riddle auto iv = conditionBlock->getArgument(0); 311*ace01605SRiver Riddle 312*ace01605SRiver Riddle // Append the induction variable stepping logic to the last body block and 313*ace01605SRiver Riddle // branch back to the condition block. Loop-carried values are taken from 314*ace01605SRiver Riddle // operands of the loop terminator. 315*ace01605SRiver Riddle Operation *terminator = lastBodyBlock->getTerminator(); 316*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(lastBodyBlock); 317*ace01605SRiver Riddle auto step = forOp.getStep(); 318*ace01605SRiver Riddle auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult(); 319*ace01605SRiver Riddle if (!stepped) 320*ace01605SRiver Riddle return failure(); 321*ace01605SRiver Riddle 322*ace01605SRiver Riddle SmallVector<Value, 8> loopCarried; 323*ace01605SRiver Riddle loopCarried.push_back(stepped); 324*ace01605SRiver Riddle loopCarried.append(terminator->operand_begin(), terminator->operand_end()); 325*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried); 326*ace01605SRiver Riddle rewriter.eraseOp(terminator); 327*ace01605SRiver Riddle 328*ace01605SRiver Riddle // Compute loop bounds before branching to the condition. 329*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(initBlock); 330*ace01605SRiver Riddle Value lowerBound = forOp.getLowerBound(); 331*ace01605SRiver Riddle Value upperBound = forOp.getUpperBound(); 332*ace01605SRiver Riddle if (!lowerBound || !upperBound) 333*ace01605SRiver Riddle return failure(); 334*ace01605SRiver Riddle 335*ace01605SRiver Riddle // The initial values of loop-carried values is obtained from the operands 336*ace01605SRiver Riddle // of the loop operation. 337*ace01605SRiver Riddle SmallVector<Value, 8> destOperands; 338*ace01605SRiver Riddle destOperands.push_back(lowerBound); 339*ace01605SRiver Riddle auto iterOperands = forOp.getIterOperands(); 340*ace01605SRiver Riddle destOperands.append(iterOperands.begin(), iterOperands.end()); 341*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands); 342*ace01605SRiver Riddle 343*ace01605SRiver Riddle // With the body block done, we can fill in the condition block. 344*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(conditionBlock); 345*ace01605SRiver Riddle auto comparison = rewriter.create<arith::CmpIOp>( 346*ace01605SRiver Riddle loc, arith::CmpIPredicate::slt, iv, upperBound); 347*ace01605SRiver Riddle 348*ace01605SRiver Riddle rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock, 349*ace01605SRiver Riddle ArrayRef<Value>(), endBlock, 350*ace01605SRiver Riddle ArrayRef<Value>()); 351*ace01605SRiver Riddle // The result of the loop operation is the values of the condition block 352*ace01605SRiver Riddle // arguments except the induction variable on the last iteration. 353*ace01605SRiver Riddle rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); 354*ace01605SRiver Riddle return success(); 355*ace01605SRiver Riddle } 356*ace01605SRiver Riddle 357*ace01605SRiver Riddle LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, 358*ace01605SRiver Riddle PatternRewriter &rewriter) const { 359*ace01605SRiver Riddle auto loc = ifOp.getLoc(); 360*ace01605SRiver Riddle 361*ace01605SRiver Riddle // Start by splitting the block containing the 'scf.if' into two parts. 362*ace01605SRiver Riddle // The part before will contain the condition, the part after will be the 363*ace01605SRiver Riddle // continuation point. 364*ace01605SRiver Riddle auto *condBlock = rewriter.getInsertionBlock(); 365*ace01605SRiver Riddle auto opPosition = rewriter.getInsertionPoint(); 366*ace01605SRiver Riddle auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 367*ace01605SRiver Riddle Block *continueBlock; 368*ace01605SRiver Riddle if (ifOp.getNumResults() == 0) { 369*ace01605SRiver Riddle continueBlock = remainingOpsBlock; 370*ace01605SRiver Riddle } else { 371*ace01605SRiver Riddle continueBlock = 372*ace01605SRiver Riddle rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), 373*ace01605SRiver Riddle SmallVector<Location>(ifOp.getNumResults(), loc)); 374*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, remainingOpsBlock); 375*ace01605SRiver Riddle } 376*ace01605SRiver Riddle 377*ace01605SRiver Riddle // Move blocks from the "then" region to the region containing 'scf.if', 378*ace01605SRiver Riddle // place it before the continuation block, and branch to it. 379*ace01605SRiver Riddle auto &thenRegion = ifOp.getThenRegion(); 380*ace01605SRiver Riddle auto *thenBlock = &thenRegion.front(); 381*ace01605SRiver Riddle Operation *thenTerminator = thenRegion.back().getTerminator(); 382*ace01605SRiver Riddle ValueRange thenTerminatorOperands = thenTerminator->getOperands(); 383*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(&thenRegion.back()); 384*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands); 385*ace01605SRiver Riddle rewriter.eraseOp(thenTerminator); 386*ace01605SRiver Riddle rewriter.inlineRegionBefore(thenRegion, continueBlock); 387*ace01605SRiver Riddle 388*ace01605SRiver Riddle // Move blocks from the "else" region (if present) to the region containing 389*ace01605SRiver Riddle // 'scf.if', place it before the continuation block and branch to it. It 390*ace01605SRiver Riddle // will be placed after the "then" regions. 391*ace01605SRiver Riddle auto *elseBlock = continueBlock; 392*ace01605SRiver Riddle auto &elseRegion = ifOp.getElseRegion(); 393*ace01605SRiver Riddle if (!elseRegion.empty()) { 394*ace01605SRiver Riddle elseBlock = &elseRegion.front(); 395*ace01605SRiver Riddle Operation *elseTerminator = elseRegion.back().getTerminator(); 396*ace01605SRiver Riddle ValueRange elseTerminatorOperands = elseTerminator->getOperands(); 397*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(&elseRegion.back()); 398*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands); 399*ace01605SRiver Riddle rewriter.eraseOp(elseTerminator); 400*ace01605SRiver Riddle rewriter.inlineRegionBefore(elseRegion, continueBlock); 401*ace01605SRiver Riddle } 402*ace01605SRiver Riddle 403*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(condBlock); 404*ace01605SRiver Riddle rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock, 405*ace01605SRiver Riddle /*trueArgs=*/ArrayRef<Value>(), elseBlock, 406*ace01605SRiver Riddle /*falseArgs=*/ArrayRef<Value>()); 407*ace01605SRiver Riddle 408*ace01605SRiver Riddle // Ok, we're done! 409*ace01605SRiver Riddle rewriter.replaceOp(ifOp, continueBlock->getArguments()); 410*ace01605SRiver Riddle return success(); 411*ace01605SRiver Riddle } 412*ace01605SRiver Riddle 413*ace01605SRiver Riddle LogicalResult 414*ace01605SRiver Riddle ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, 415*ace01605SRiver Riddle PatternRewriter &rewriter) const { 416*ace01605SRiver Riddle auto loc = op.getLoc(); 417*ace01605SRiver Riddle 418*ace01605SRiver Riddle auto *condBlock = rewriter.getInsertionBlock(); 419*ace01605SRiver Riddle auto opPosition = rewriter.getInsertionPoint(); 420*ace01605SRiver Riddle auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 421*ace01605SRiver Riddle 422*ace01605SRiver Riddle auto ®ion = op.getRegion(); 423*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(condBlock); 424*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, ®ion.front()); 425*ace01605SRiver Riddle 426*ace01605SRiver Riddle for (Block &block : region) { 427*ace01605SRiver Riddle if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) { 428*ace01605SRiver Riddle ValueRange terminatorOperands = terminator->getOperands(); 429*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(&block); 430*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands); 431*ace01605SRiver Riddle rewriter.eraseOp(terminator); 432*ace01605SRiver Riddle } 433*ace01605SRiver Riddle } 434*ace01605SRiver Riddle 435*ace01605SRiver Riddle rewriter.inlineRegionBefore(region, remainingOpsBlock); 436*ace01605SRiver Riddle 437*ace01605SRiver Riddle SmallVector<Value> vals; 438*ace01605SRiver Riddle SmallVector<Location> argLocs(op.getNumResults(), op->getLoc()); 439*ace01605SRiver Riddle for (auto arg : 440*ace01605SRiver Riddle remainingOpsBlock->addArguments(op->getResultTypes(), argLocs)) 441*ace01605SRiver Riddle vals.push_back(arg); 442*ace01605SRiver Riddle rewriter.replaceOp(op, vals); 443*ace01605SRiver Riddle return success(); 444*ace01605SRiver Riddle } 445*ace01605SRiver Riddle 446*ace01605SRiver Riddle LogicalResult 447*ace01605SRiver Riddle ParallelLowering::matchAndRewrite(ParallelOp parallelOp, 448*ace01605SRiver Riddle PatternRewriter &rewriter) const { 449*ace01605SRiver Riddle Location loc = parallelOp.getLoc(); 450*ace01605SRiver Riddle 451*ace01605SRiver Riddle // For a parallel loop, we essentially need to create an n-dimensional loop 452*ace01605SRiver Riddle // nest. We do this by translating to scf.for ops and have those lowered in 453*ace01605SRiver Riddle // a further rewrite. If a parallel loop contains reductions (and thus returns 454*ace01605SRiver Riddle // values), forward the initial values for the reductions down the loop 455*ace01605SRiver Riddle // hierarchy and bubble up the results by modifying the "yield" terminator. 456*ace01605SRiver Riddle SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals()); 457*ace01605SRiver Riddle SmallVector<Value, 4> ivs; 458*ace01605SRiver Riddle ivs.reserve(parallelOp.getNumLoops()); 459*ace01605SRiver Riddle bool first = true; 460*ace01605SRiver Riddle SmallVector<Value, 4> loopResults(iterArgs); 461*ace01605SRiver Riddle for (auto loopOperands : 462*ace01605SRiver Riddle llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), 463*ace01605SRiver Riddle parallelOp.getUpperBound(), parallelOp.getStep())) { 464*ace01605SRiver Riddle Value iv, lower, upper, step; 465*ace01605SRiver Riddle std::tie(iv, lower, upper, step) = loopOperands; 466*ace01605SRiver Riddle ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs); 467*ace01605SRiver Riddle ivs.push_back(forOp.getInductionVar()); 468*ace01605SRiver Riddle auto iterRange = forOp.getRegionIterArgs(); 469*ace01605SRiver Riddle iterArgs.assign(iterRange.begin(), iterRange.end()); 470*ace01605SRiver Riddle 471*ace01605SRiver Riddle if (first) { 472*ace01605SRiver Riddle // Store the results of the outermost loop that will be used to replace 473*ace01605SRiver Riddle // the results of the parallel loop when it is fully rewritten. 474*ace01605SRiver Riddle loopResults.assign(forOp.result_begin(), forOp.result_end()); 475*ace01605SRiver Riddle first = false; 476*ace01605SRiver Riddle } else if (!forOp.getResults().empty()) { 477*ace01605SRiver Riddle // A loop is constructed with an empty "yield" terminator if there are 478*ace01605SRiver Riddle // no results. 479*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); 480*ace01605SRiver Riddle rewriter.create<scf::YieldOp>(loc, forOp.getResults()); 481*ace01605SRiver Riddle } 482*ace01605SRiver Riddle 483*ace01605SRiver Riddle rewriter.setInsertionPointToStart(forOp.getBody()); 484*ace01605SRiver Riddle } 485*ace01605SRiver Riddle 486*ace01605SRiver Riddle // First, merge reduction blocks into the main region. 487*ace01605SRiver Riddle SmallVector<Value, 4> yieldOperands; 488*ace01605SRiver Riddle yieldOperands.reserve(parallelOp.getNumResults()); 489*ace01605SRiver Riddle for (auto &op : *parallelOp.getBody()) { 490*ace01605SRiver Riddle auto reduce = dyn_cast<ReduceOp>(op); 491*ace01605SRiver Riddle if (!reduce) 492*ace01605SRiver Riddle continue; 493*ace01605SRiver Riddle 494*ace01605SRiver Riddle Block &reduceBlock = reduce.getReductionOperator().front(); 495*ace01605SRiver Riddle Value arg = iterArgs[yieldOperands.size()]; 496*ace01605SRiver Riddle yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0)); 497*ace01605SRiver Riddle rewriter.eraseOp(reduceBlock.getTerminator()); 498*ace01605SRiver Riddle rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()}); 499*ace01605SRiver Riddle rewriter.eraseOp(reduce); 500*ace01605SRiver Riddle } 501*ace01605SRiver Riddle 502*ace01605SRiver Riddle // Then merge the loop body without the terminator. 503*ace01605SRiver Riddle rewriter.eraseOp(parallelOp.getBody()->getTerminator()); 504*ace01605SRiver Riddle Block *newBody = rewriter.getInsertionBlock(); 505*ace01605SRiver Riddle if (newBody->empty()) 506*ace01605SRiver Riddle rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs); 507*ace01605SRiver Riddle else 508*ace01605SRiver Riddle rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(), 509*ace01605SRiver Riddle ivs); 510*ace01605SRiver Riddle 511*ace01605SRiver Riddle // Finally, create the terminator if required (for loops with no results, it 512*ace01605SRiver Riddle // has been already created in loop construction). 513*ace01605SRiver Riddle if (!yieldOperands.empty()) { 514*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); 515*ace01605SRiver Riddle rewriter.create<scf::YieldOp>(loc, yieldOperands); 516*ace01605SRiver Riddle } 517*ace01605SRiver Riddle 518*ace01605SRiver Riddle rewriter.replaceOp(parallelOp, loopResults); 519*ace01605SRiver Riddle 520*ace01605SRiver Riddle return success(); 521*ace01605SRiver Riddle } 522*ace01605SRiver Riddle 523*ace01605SRiver Riddle LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, 524*ace01605SRiver Riddle PatternRewriter &rewriter) const { 525*ace01605SRiver Riddle OpBuilder::InsertionGuard guard(rewriter); 526*ace01605SRiver Riddle Location loc = whileOp.getLoc(); 527*ace01605SRiver Riddle 528*ace01605SRiver Riddle // Split the current block before the WhileOp to create the inlining point. 529*ace01605SRiver Riddle Block *currentBlock = rewriter.getInsertionBlock(); 530*ace01605SRiver Riddle Block *continuation = 531*ace01605SRiver Riddle rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 532*ace01605SRiver Riddle 533*ace01605SRiver Riddle // Inline both regions. 534*ace01605SRiver Riddle Block *after = &whileOp.getAfter().front(); 535*ace01605SRiver Riddle Block *afterLast = &whileOp.getAfter().back(); 536*ace01605SRiver Riddle Block *before = &whileOp.getBefore().front(); 537*ace01605SRiver Riddle Block *beforeLast = &whileOp.getBefore().back(); 538*ace01605SRiver Riddle rewriter.inlineRegionBefore(whileOp.getAfter(), continuation); 539*ace01605SRiver Riddle rewriter.inlineRegionBefore(whileOp.getBefore(), after); 540*ace01605SRiver Riddle 541*ace01605SRiver Riddle // Branch to the "before" region. 542*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(currentBlock); 543*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits()); 544*ace01605SRiver Riddle 545*ace01605SRiver Riddle // Replace terminators with branches. Assuming bodies are SESE, which holds 546*ace01605SRiver Riddle // given only the patterns from this file, we only need to look at the last 547*ace01605SRiver Riddle // block. This should be reconsidered if we allow break/continue in SCF. 548*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(beforeLast); 549*ace01605SRiver Riddle auto condOp = cast<ConditionOp>(beforeLast->getTerminator()); 550*ace01605SRiver Riddle rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), 551*ace01605SRiver Riddle after, condOp.getArgs(), 552*ace01605SRiver Riddle continuation, ValueRange()); 553*ace01605SRiver Riddle 554*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(afterLast); 555*ace01605SRiver Riddle auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator()); 556*ace01605SRiver Riddle rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, 557*ace01605SRiver Riddle yieldOp.getResults()); 558*ace01605SRiver Riddle 559*ace01605SRiver Riddle // Replace the op with values "yielded" from the "before" region, which are 560*ace01605SRiver Riddle // visible by dominance. 561*ace01605SRiver Riddle rewriter.replaceOp(whileOp, condOp.getArgs()); 562*ace01605SRiver Riddle 563*ace01605SRiver Riddle return success(); 564*ace01605SRiver Riddle } 565*ace01605SRiver Riddle 566*ace01605SRiver Riddle LogicalResult 567*ace01605SRiver Riddle DoWhileLowering::matchAndRewrite(WhileOp whileOp, 568*ace01605SRiver Riddle PatternRewriter &rewriter) const { 569*ace01605SRiver Riddle if (!llvm::hasSingleElement(whileOp.getAfter())) 570*ace01605SRiver Riddle return rewriter.notifyMatchFailure(whileOp, 571*ace01605SRiver Riddle "do-while simplification applicable to " 572*ace01605SRiver Riddle "single-block 'after' region only"); 573*ace01605SRiver Riddle 574*ace01605SRiver Riddle Block &afterBlock = whileOp.getAfter().front(); 575*ace01605SRiver Riddle if (!llvm::hasSingleElement(afterBlock)) 576*ace01605SRiver Riddle return rewriter.notifyMatchFailure(whileOp, 577*ace01605SRiver Riddle "do-while simplification applicable " 578*ace01605SRiver Riddle "only if 'after' region has no payload"); 579*ace01605SRiver Riddle 580*ace01605SRiver Riddle auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front()); 581*ace01605SRiver Riddle if (!yield || yield.getResults() != afterBlock.getArguments()) 582*ace01605SRiver Riddle return rewriter.notifyMatchFailure(whileOp, 583*ace01605SRiver Riddle "do-while simplification applicable " 584*ace01605SRiver Riddle "only to forwarding 'after' regions"); 585*ace01605SRiver Riddle 586*ace01605SRiver Riddle // Split the current block before the WhileOp to create the inlining point. 587*ace01605SRiver Riddle OpBuilder::InsertionGuard guard(rewriter); 588*ace01605SRiver Riddle Block *currentBlock = rewriter.getInsertionBlock(); 589*ace01605SRiver Riddle Block *continuation = 590*ace01605SRiver Riddle rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); 591*ace01605SRiver Riddle 592*ace01605SRiver Riddle // Only the "before" region should be inlined. 593*ace01605SRiver Riddle Block *before = &whileOp.getBefore().front(); 594*ace01605SRiver Riddle Block *beforeLast = &whileOp.getBefore().back(); 595*ace01605SRiver Riddle rewriter.inlineRegionBefore(whileOp.getBefore(), continuation); 596*ace01605SRiver Riddle 597*ace01605SRiver Riddle // Branch to the "before" region. 598*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(currentBlock); 599*ace01605SRiver Riddle rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits()); 600*ace01605SRiver Riddle 601*ace01605SRiver Riddle // Loop around the "before" region based on condition. 602*ace01605SRiver Riddle rewriter.setInsertionPointToEnd(beforeLast); 603*ace01605SRiver Riddle auto condOp = cast<ConditionOp>(beforeLast->getTerminator()); 604*ace01605SRiver Riddle rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), 605*ace01605SRiver Riddle before, condOp.getArgs(), 606*ace01605SRiver Riddle continuation, ValueRange()); 607*ace01605SRiver Riddle 608*ace01605SRiver Riddle // Replace the op with values "yielded" from the "before" region, which are 609*ace01605SRiver Riddle // visible by dominance. 610*ace01605SRiver Riddle rewriter.replaceOp(whileOp, condOp.getArgs()); 611*ace01605SRiver Riddle 612*ace01605SRiver Riddle return success(); 613*ace01605SRiver Riddle } 614*ace01605SRiver Riddle 615*ace01605SRiver Riddle void mlir::populateSCFToControlFlowConversionPatterns( 616*ace01605SRiver Riddle RewritePatternSet &patterns) { 617*ace01605SRiver Riddle patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering, 618*ace01605SRiver Riddle ExecuteRegionLowering>(patterns.getContext()); 619*ace01605SRiver Riddle patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2); 620*ace01605SRiver Riddle } 621*ace01605SRiver Riddle 622*ace01605SRiver Riddle void SCFToControlFlowPass::runOnOperation() { 623*ace01605SRiver Riddle RewritePatternSet patterns(&getContext()); 624*ace01605SRiver Riddle populateSCFToControlFlowConversionPatterns(patterns); 625*ace01605SRiver Riddle 626*ace01605SRiver Riddle // Configure conversion to lower out SCF operations. 627*ace01605SRiver Riddle ConversionTarget target(getContext()); 628*ace01605SRiver Riddle target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp, 629*ace01605SRiver Riddle scf::ExecuteRegionOp>(); 630*ace01605SRiver Riddle target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 631*ace01605SRiver Riddle if (failed( 632*ace01605SRiver Riddle applyPartialConversion(getOperation(), target, std::move(patterns)))) 633*ace01605SRiver Riddle signalPassFailure(); 634*ace01605SRiver Riddle } 635*ace01605SRiver Riddle 636*ace01605SRiver Riddle std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() { 637*ace01605SRiver Riddle return std::make_unique<SCFToControlFlowPass>(); 638*ace01605SRiver Riddle } 639