168d69237SJean Perier //===-- RewriteLoop.cpp ---------------------------------------------------===// 268d69237SJean Perier // 368d69237SJean Perier // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 468d69237SJean Perier // See https://llvm.org/LICENSE.txt for license information. 568d69237SJean Perier // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 668d69237SJean Perier // 768d69237SJean Perier //===----------------------------------------------------------------------===// 868d69237SJean Perier 968d69237SJean Perier #include "PassDetail.h" 1068d69237SJean Perier #include "flang/Optimizer/Dialect/FIRDialect.h" 1168d69237SJean Perier #include "flang/Optimizer/Dialect/FIROps.h" 1268d69237SJean Perier #include "flang/Optimizer/Transforms/Passes.h" 1368d69237SJean Perier #include "mlir/Dialect/Affine/IR/AffineOps.h" 1468d69237SJean Perier #include "mlir/Dialect/StandardOps/IR/Ops.h" 1568d69237SJean Perier #include "mlir/Pass/Pass.h" 1668d69237SJean Perier #include "mlir/Transforms/DialectConversion.h" 1768d69237SJean Perier #include "llvm/Support/CommandLine.h" 1868d69237SJean Perier 1968d69237SJean Perier using namespace fir; 2068d69237SJean Perier 2168d69237SJean Perier namespace { 2268d69237SJean Perier 2368d69237SJean Perier // Conversion of fir control ops to more primitive control-flow. 2468d69237SJean Perier // 2568d69237SJean Perier // FIR loops that cannot be converted to the affine dialect will remain as 2668d69237SJean Perier // `fir.do_loop` operations. These can be converted to control-flow operations. 2768d69237SJean Perier 2868d69237SJean Perier /// Convert `fir.do_loop` to CFG 2968d69237SJean Perier class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> { 3068d69237SJean Perier public: 3168d69237SJean Perier using OpRewritePattern::OpRewritePattern; 3268d69237SJean Perier 3368d69237SJean Perier CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) 3468d69237SJean Perier : mlir::OpRewritePattern<fir::DoLoopOp>(ctx), 3568d69237SJean Perier forceLoopToExecuteOnce(forceLoopToExecuteOnce) {} 3668d69237SJean Perier 3768d69237SJean Perier mlir::LogicalResult 3868d69237SJean Perier matchAndRewrite(DoLoopOp loop, 3968d69237SJean Perier mlir::PatternRewriter &rewriter) const override { 4068d69237SJean Perier auto loc = loop.getLoc(); 4168d69237SJean Perier 4268d69237SJean Perier // Create the start and end blocks that will wrap the DoLoopOp with an 4368d69237SJean Perier // initalizer and an end point 4468d69237SJean Perier auto *initBlock = rewriter.getInsertionBlock(); 4568d69237SJean Perier auto initPos = rewriter.getInsertionPoint(); 4668d69237SJean Perier auto *endBlock = rewriter.splitBlock(initBlock, initPos); 4768d69237SJean Perier 4868d69237SJean Perier // Split the first DoLoopOp block in two parts. The part before will be the 4968d69237SJean Perier // conditional block since it already has the induction variable and 5068d69237SJean Perier // loop-carried values as arguments. 5168d69237SJean Perier auto *conditionalBlock = &loop.region().front(); 5268d69237SJean Perier conditionalBlock->addArgument(rewriter.getIndexType()); 5368d69237SJean Perier auto *firstBlock = 5468d69237SJean Perier rewriter.splitBlock(conditionalBlock, conditionalBlock->begin()); 5568d69237SJean Perier auto *lastBlock = &loop.region().back(); 5668d69237SJean Perier 5768d69237SJean Perier // Move the blocks from the DoLoopOp between initBlock and endBlock 5868d69237SJean Perier rewriter.inlineRegionBefore(loop.region(), endBlock); 5968d69237SJean Perier 6068d69237SJean Perier // Get loop values from the DoLoopOp 6168d69237SJean Perier auto low = loop.lowerBound(); 6268d69237SJean Perier auto high = loop.upperBound(); 6368d69237SJean Perier assert(low && high && "must be a Value"); 6468d69237SJean Perier auto step = loop.step(); 6568d69237SJean Perier 6668d69237SJean Perier // Initalization block 6768d69237SJean Perier rewriter.setInsertionPointToEnd(initBlock); 68*a54f4eaeSMogball auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low); 69*a54f4eaeSMogball auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step); 7068d69237SJean Perier mlir::Value iters = 71*a54f4eaeSMogball rewriter.create<mlir::arith::DivSIOp>(loc, distance, step); 7268d69237SJean Perier 7368d69237SJean Perier if (forceLoopToExecuteOnce) { 74*a54f4eaeSMogball auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 75*a54f4eaeSMogball auto cond = rewriter.create<mlir::arith::CmpIOp>( 76*a54f4eaeSMogball loc, arith::CmpIPredicate::sle, iters, zero); 77*a54f4eaeSMogball auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 7868d69237SJean Perier iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters); 7968d69237SJean Perier } 8068d69237SJean Perier 8168d69237SJean Perier llvm::SmallVector<mlir::Value> loopOperands; 8268d69237SJean Perier loopOperands.push_back(low); 8368d69237SJean Perier auto operands = loop.getIterOperands(); 8468d69237SJean Perier loopOperands.append(operands.begin(), operands.end()); 8568d69237SJean Perier loopOperands.push_back(iters); 8668d69237SJean Perier 8768d69237SJean Perier rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands); 8868d69237SJean Perier 8968d69237SJean Perier // Last loop block 9068d69237SJean Perier auto *terminator = lastBlock->getTerminator(); 9168d69237SJean Perier rewriter.setInsertionPointToEnd(lastBlock); 9268d69237SJean Perier auto iv = conditionalBlock->getArgument(0); 93*a54f4eaeSMogball mlir::Value steppedIndex = 94*a54f4eaeSMogball rewriter.create<mlir::arith::AddIOp>(loc, iv, step); 9568d69237SJean Perier assert(steppedIndex && "must be a Value"); 9668d69237SJean Perier auto lastArg = conditionalBlock->getNumArguments() - 1; 9768d69237SJean Perier auto itersLeft = conditionalBlock->getArgument(lastArg); 98*a54f4eaeSMogball auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 9968d69237SJean Perier mlir::Value itersMinusOne = 100*a54f4eaeSMogball rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one); 10168d69237SJean Perier 10268d69237SJean Perier llvm::SmallVector<mlir::Value> loopCarried; 10368d69237SJean Perier loopCarried.push_back(steppedIndex); 10468d69237SJean Perier auto begin = loop.finalValue() ? std::next(terminator->operand_begin()) 10568d69237SJean Perier : terminator->operand_begin(); 10668d69237SJean Perier loopCarried.append(begin, terminator->operand_end()); 10768d69237SJean Perier loopCarried.push_back(itersMinusOne); 10868d69237SJean Perier rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried); 10968d69237SJean Perier rewriter.eraseOp(terminator); 11068d69237SJean Perier 11168d69237SJean Perier // Conditional block 11268d69237SJean Perier rewriter.setInsertionPointToEnd(conditionalBlock); 113*a54f4eaeSMogball auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 114*a54f4eaeSMogball auto comparison = rewriter.create<mlir::arith::CmpIOp>( 115*a54f4eaeSMogball loc, arith::CmpIPredicate::sgt, itersLeft, zero); 11668d69237SJean Perier 11768d69237SJean Perier rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock, 11868d69237SJean Perier llvm::ArrayRef<mlir::Value>(), endBlock, 11968d69237SJean Perier llvm::ArrayRef<mlir::Value>()); 12068d69237SJean Perier 12168d69237SJean Perier // The result of the loop operation is the values of the condition block 12268d69237SJean Perier // arguments except the induction variable on the last iteration. 12368d69237SJean Perier auto args = loop.finalValue() 12468d69237SJean Perier ? conditionalBlock->getArguments() 12568d69237SJean Perier : conditionalBlock->getArguments().drop_front(); 12668d69237SJean Perier rewriter.replaceOp(loop, args.drop_back()); 12768d69237SJean Perier return success(); 12868d69237SJean Perier } 12968d69237SJean Perier 13068d69237SJean Perier private: 13168d69237SJean Perier bool forceLoopToExecuteOnce; 13268d69237SJean Perier }; 13368d69237SJean Perier 13468d69237SJean Perier /// Convert `fir.if` to control-flow 13568d69237SJean Perier class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> { 13668d69237SJean Perier public: 13768d69237SJean Perier using OpRewritePattern::OpRewritePattern; 13868d69237SJean Perier 13968d69237SJean Perier CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) 14068d69237SJean Perier : mlir::OpRewritePattern<fir::IfOp>(ctx), 14168d69237SJean Perier forceLoopToExecuteOnce(forceLoopToExecuteOnce) {} 14268d69237SJean Perier 14368d69237SJean Perier mlir::LogicalResult 14468d69237SJean Perier matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override { 14568d69237SJean Perier auto loc = ifOp.getLoc(); 14668d69237SJean Perier 14768d69237SJean Perier // Split the block containing the 'fir.if' into two parts. The part before 14868d69237SJean Perier // will contain the condition, the part after will be the continuation 14968d69237SJean Perier // point. 15068d69237SJean Perier auto *condBlock = rewriter.getInsertionBlock(); 15168d69237SJean Perier auto opPosition = rewriter.getInsertionPoint(); 15268d69237SJean Perier auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 15368d69237SJean Perier mlir::Block *continueBlock; 15468d69237SJean Perier if (ifOp.getNumResults() == 0) { 15568d69237SJean Perier continueBlock = remainingOpsBlock; 15668d69237SJean Perier } else { 15768d69237SJean Perier continueBlock = 15868d69237SJean Perier rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes()); 15968d69237SJean Perier rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock); 16068d69237SJean Perier } 16168d69237SJean Perier 16268d69237SJean Perier // Move blocks from the "then" region to the region containing 'fir.if', 16368d69237SJean Perier // place it before the continuation block, and branch to it. 16468d69237SJean Perier auto &ifOpRegion = ifOp.thenRegion(); 16568d69237SJean Perier auto *ifOpBlock = &ifOpRegion.front(); 16668d69237SJean Perier auto *ifOpTerminator = ifOpRegion.back().getTerminator(); 16768d69237SJean Perier auto ifOpTerminatorOperands = ifOpTerminator->getOperands(); 16868d69237SJean Perier rewriter.setInsertionPointToEnd(&ifOpRegion.back()); 16968d69237SJean Perier rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands); 17068d69237SJean Perier rewriter.eraseOp(ifOpTerminator); 17168d69237SJean Perier rewriter.inlineRegionBefore(ifOpRegion, continueBlock); 17268d69237SJean Perier 17368d69237SJean Perier // Move blocks from the "else" region (if present) to the region containing 17468d69237SJean Perier // 'fir.if', place it before the continuation block and branch to it. It 17568d69237SJean Perier // will be placed after the "then" regions. 17668d69237SJean Perier auto *otherwiseBlock = continueBlock; 17768d69237SJean Perier auto &otherwiseRegion = ifOp.elseRegion(); 17868d69237SJean Perier if (!otherwiseRegion.empty()) { 17968d69237SJean Perier otherwiseBlock = &otherwiseRegion.front(); 18068d69237SJean Perier auto *otherwiseTerm = otherwiseRegion.back().getTerminator(); 18168d69237SJean Perier auto otherwiseTermOperands = otherwiseTerm->getOperands(); 18268d69237SJean Perier rewriter.setInsertionPointToEnd(&otherwiseRegion.back()); 18368d69237SJean Perier rewriter.create<mlir::BranchOp>(loc, continueBlock, 18468d69237SJean Perier otherwiseTermOperands); 18568d69237SJean Perier rewriter.eraseOp(otherwiseTerm); 18668d69237SJean Perier rewriter.inlineRegionBefore(otherwiseRegion, continueBlock); 18768d69237SJean Perier } 18868d69237SJean Perier 18968d69237SJean Perier rewriter.setInsertionPointToEnd(condBlock); 19068d69237SJean Perier rewriter.create<mlir::CondBranchOp>( 19168d69237SJean Perier loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(), 19268d69237SJean Perier otherwiseBlock, llvm::ArrayRef<mlir::Value>()); 19368d69237SJean Perier rewriter.replaceOp(ifOp, continueBlock->getArguments()); 19468d69237SJean Perier return success(); 19568d69237SJean Perier } 19668d69237SJean Perier 19768d69237SJean Perier private: 19868d69237SJean Perier bool forceLoopToExecuteOnce; 19968d69237SJean Perier }; 20068d69237SJean Perier 20168d69237SJean Perier /// Convert `fir.iter_while` to control-flow. 20268d69237SJean Perier class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> { 20368d69237SJean Perier public: 20468d69237SJean Perier using OpRewritePattern::OpRewritePattern; 20568d69237SJean Perier 20668d69237SJean Perier CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) 20768d69237SJean Perier : mlir::OpRewritePattern<fir::IterWhileOp>(ctx), 20868d69237SJean Perier forceLoopToExecuteOnce(forceLoopToExecuteOnce) {} 20968d69237SJean Perier 21068d69237SJean Perier mlir::LogicalResult 21168d69237SJean Perier matchAndRewrite(fir::IterWhileOp whileOp, 21268d69237SJean Perier mlir::PatternRewriter &rewriter) const override { 21368d69237SJean Perier auto loc = whileOp.getLoc(); 21468d69237SJean Perier 21568d69237SJean Perier // Start by splitting the block containing the 'fir.do_loop' into two parts. 21668d69237SJean Perier // The part before will get the init code, the part after will be the end 21768d69237SJean Perier // point. 21868d69237SJean Perier auto *initBlock = rewriter.getInsertionBlock(); 21968d69237SJean Perier auto initPosition = rewriter.getInsertionPoint(); 22068d69237SJean Perier auto *endBlock = rewriter.splitBlock(initBlock, initPosition); 22168d69237SJean Perier 22268d69237SJean Perier // Use the first block of the loop body as the condition block since it is 22368d69237SJean Perier // the block that has the induction variable and loop-carried values as 22468d69237SJean Perier // arguments. Split out all operations from the first block into a new 22568d69237SJean Perier // block. Move all body blocks from the loop body region to the region 22668d69237SJean Perier // containing the loop. 22768d69237SJean Perier auto *conditionBlock = &whileOp.region().front(); 22868d69237SJean Perier auto *firstBodyBlock = 22968d69237SJean Perier rewriter.splitBlock(conditionBlock, conditionBlock->begin()); 23068d69237SJean Perier auto *lastBodyBlock = &whileOp.region().back(); 23168d69237SJean Perier rewriter.inlineRegionBefore(whileOp.region(), endBlock); 23268d69237SJean Perier auto iv = conditionBlock->getArgument(0); 23368d69237SJean Perier auto iterateVar = conditionBlock->getArgument(1); 23468d69237SJean Perier 23568d69237SJean Perier // Append the induction variable stepping logic to the last body block and 23668d69237SJean Perier // branch back to the condition block. Loop-carried values are taken from 23768d69237SJean Perier // operands of the loop terminator. 23868d69237SJean Perier auto *terminator = lastBodyBlock->getTerminator(); 23968d69237SJean Perier rewriter.setInsertionPointToEnd(lastBodyBlock); 24068d69237SJean Perier auto step = whileOp.step(); 241*a54f4eaeSMogball mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step); 24268d69237SJean Perier assert(stepped && "must be a Value"); 24368d69237SJean Perier 24468d69237SJean Perier llvm::SmallVector<mlir::Value> loopCarried; 24568d69237SJean Perier loopCarried.push_back(stepped); 24668d69237SJean Perier auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin()) 24768d69237SJean Perier : terminator->operand_begin(); 24868d69237SJean Perier loopCarried.append(begin, terminator->operand_end()); 24968d69237SJean Perier rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried); 25068d69237SJean Perier rewriter.eraseOp(terminator); 25168d69237SJean Perier 25268d69237SJean Perier // Compute loop bounds before branching to the condition. 25368d69237SJean Perier rewriter.setInsertionPointToEnd(initBlock); 25468d69237SJean Perier auto lowerBound = whileOp.lowerBound(); 25568d69237SJean Perier auto upperBound = whileOp.upperBound(); 25668d69237SJean Perier assert(lowerBound && upperBound && "must be a Value"); 25768d69237SJean Perier 25868d69237SJean Perier // The initial values of loop-carried values is obtained from the operands 25968d69237SJean Perier // of the loop operation. 26068d69237SJean Perier llvm::SmallVector<mlir::Value> destOperands; 26168d69237SJean Perier destOperands.push_back(lowerBound); 26268d69237SJean Perier auto iterOperands = whileOp.getIterOperands(); 26368d69237SJean Perier destOperands.append(iterOperands.begin(), iterOperands.end()); 26468d69237SJean Perier rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands); 26568d69237SJean Perier 26668d69237SJean Perier // With the body block done, we can fill in the condition block. 26768d69237SJean Perier rewriter.setInsertionPointToEnd(conditionBlock); 26868d69237SJean Perier // The comparison depends on the sign of the step value. We fully expect 26968d69237SJean Perier // this expression to be folded by the optimizer or LLVM. This expression 27068d69237SJean Perier // is written this way so that `step == 0` always returns `false`. 271*a54f4eaeSMogball auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 272*a54f4eaeSMogball auto compl0 = rewriter.create<mlir::arith::CmpIOp>( 273*a54f4eaeSMogball loc, arith::CmpIPredicate::slt, zero, step); 274*a54f4eaeSMogball auto compl1 = rewriter.create<mlir::arith::CmpIOp>( 275*a54f4eaeSMogball loc, arith::CmpIPredicate::sle, iv, upperBound); 276*a54f4eaeSMogball auto compl2 = rewriter.create<mlir::arith::CmpIOp>( 277*a54f4eaeSMogball loc, arith::CmpIPredicate::slt, step, zero); 278*a54f4eaeSMogball auto compl3 = rewriter.create<mlir::arith::CmpIOp>( 279*a54f4eaeSMogball loc, arith::CmpIPredicate::sle, upperBound, iv); 280*a54f4eaeSMogball auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1); 281*a54f4eaeSMogball auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3); 282*a54f4eaeSMogball auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1); 28368d69237SJean Perier // Remember to AND in the early-exit bool. 284*a54f4eaeSMogball auto comparison = 285*a54f4eaeSMogball rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2); 28668d69237SJean Perier rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock, 28768d69237SJean Perier llvm::ArrayRef<mlir::Value>(), endBlock, 28868d69237SJean Perier llvm::ArrayRef<mlir::Value>()); 28968d69237SJean Perier // The result of the loop operation is the values of the condition block 29068d69237SJean Perier // arguments except the induction variable on the last iteration. 29168d69237SJean Perier auto args = whileOp.finalValue() 29268d69237SJean Perier ? conditionBlock->getArguments() 29368d69237SJean Perier : conditionBlock->getArguments().drop_front(); 29468d69237SJean Perier rewriter.replaceOp(whileOp, args); 29568d69237SJean Perier return success(); 29668d69237SJean Perier } 29768d69237SJean Perier 29868d69237SJean Perier private: 29968d69237SJean Perier bool forceLoopToExecuteOnce; 30068d69237SJean Perier }; 30168d69237SJean Perier 30268d69237SJean Perier /// Convert FIR structured control flow ops to CFG ops. 30368d69237SJean Perier class CfgConversion : public CFGConversionBase<CfgConversion> { 30468d69237SJean Perier public: 30568d69237SJean Perier void runOnFunction() override { 30668d69237SJean Perier auto *context = &getContext(); 30768d69237SJean Perier mlir::OwningRewritePatternList patterns(context); 30868d69237SJean Perier patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( 30968d69237SJean Perier context, forceLoopToExecuteOnce); 31068d69237SJean Perier mlir::ConversionTarget target(*context); 31168d69237SJean Perier target.addLegalDialect<mlir::AffineDialect, FIROpsDialect, 31268d69237SJean Perier mlir::StandardOpsDialect>(); 31368d69237SJean Perier 31468d69237SJean Perier // apply the patterns 31568d69237SJean Perier target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); 31668d69237SJean Perier target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 31768d69237SJean Perier if (mlir::failed(mlir::applyPartialConversion(getFunction(), target, 31868d69237SJean Perier std::move(patterns)))) { 31968d69237SJean Perier mlir::emitError(mlir::UnknownLoc::get(context), 32068d69237SJean Perier "error in converting to CFG\n"); 32168d69237SJean Perier signalPassFailure(); 32268d69237SJean Perier } 32368d69237SJean Perier } 32468d69237SJean Perier }; 32568d69237SJean Perier } // namespace 32668d69237SJean Perier 32768d69237SJean Perier /// Convert FIR's structured control flow ops to CFG ops. This 32868d69237SJean Perier /// conversion enables the `createLowerToCFGPass` to transform these to CFG 32968d69237SJean Perier /// form. 33068d69237SJean Perier std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() { 33168d69237SJean Perier return std::make_unique<CfgConversion>(); 33268d69237SJean Perier } 333