168d69237SJean Perier //===-- RewriteLoop.cpp ---------------------------------------------------===//
268d69237SJean Perier //
368d69237SJean Perier // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
468d69237SJean Perier // See https://llvm.org/LICENSE.txt for license information.
568d69237SJean Perier // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
668d69237SJean Perier //
768d69237SJean Perier //===----------------------------------------------------------------------===//
868d69237SJean Perier 
968d69237SJean Perier #include "PassDetail.h"
1068d69237SJean Perier #include "flang/Optimizer/Dialect/FIRDialect.h"
1168d69237SJean Perier #include "flang/Optimizer/Dialect/FIROps.h"
1268d69237SJean Perier #include "flang/Optimizer/Transforms/Passes.h"
1368d69237SJean Perier #include "mlir/Dialect/Affine/IR/AffineOps.h"
1468d69237SJean Perier #include "mlir/Dialect/StandardOps/IR/Ops.h"
1568d69237SJean Perier #include "mlir/Pass/Pass.h"
1668d69237SJean Perier #include "mlir/Transforms/DialectConversion.h"
1768d69237SJean Perier #include "llvm/Support/CommandLine.h"
1868d69237SJean Perier 
1968d69237SJean Perier using namespace fir;
2068d69237SJean Perier 
2168d69237SJean Perier namespace {
2268d69237SJean Perier 
2368d69237SJean Perier // Conversion of fir control ops to more primitive control-flow.
2468d69237SJean Perier //
2568d69237SJean Perier // FIR loops that cannot be converted to the affine dialect will remain as
2668d69237SJean Perier // `fir.do_loop` operations.  These can be converted to control-flow operations.
2768d69237SJean Perier 
2868d69237SJean Perier /// Convert `fir.do_loop` to CFG
2968d69237SJean Perier class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
3068d69237SJean Perier public:
3168d69237SJean Perier   using OpRewritePattern::OpRewritePattern;
3268d69237SJean Perier 
3368d69237SJean Perier   CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
3468d69237SJean Perier       : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
3568d69237SJean Perier         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
3668d69237SJean Perier 
3768d69237SJean Perier   mlir::LogicalResult
3868d69237SJean Perier   matchAndRewrite(DoLoopOp loop,
3968d69237SJean Perier                   mlir::PatternRewriter &rewriter) const override {
4068d69237SJean Perier     auto loc = loop.getLoc();
4168d69237SJean Perier 
4268d69237SJean Perier     // Create the start and end blocks that will wrap the DoLoopOp with an
4368d69237SJean Perier     // initalizer and an end point
4468d69237SJean Perier     auto *initBlock = rewriter.getInsertionBlock();
4568d69237SJean Perier     auto initPos = rewriter.getInsertionPoint();
4668d69237SJean Perier     auto *endBlock = rewriter.splitBlock(initBlock, initPos);
4768d69237SJean Perier 
4868d69237SJean Perier     // Split the first DoLoopOp block in two parts. The part before will be the
4968d69237SJean Perier     // conditional block since it already has the induction variable and
5068d69237SJean Perier     // loop-carried values as arguments.
5168d69237SJean Perier     auto *conditionalBlock = &loop.region().front();
5268d69237SJean Perier     conditionalBlock->addArgument(rewriter.getIndexType());
5368d69237SJean Perier     auto *firstBlock =
5468d69237SJean Perier         rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
5568d69237SJean Perier     auto *lastBlock = &loop.region().back();
5668d69237SJean Perier 
5768d69237SJean Perier     // Move the blocks from the DoLoopOp between initBlock and endBlock
5868d69237SJean Perier     rewriter.inlineRegionBefore(loop.region(), endBlock);
5968d69237SJean Perier 
6068d69237SJean Perier     // Get loop values from the DoLoopOp
6168d69237SJean Perier     auto low = loop.lowerBound();
6268d69237SJean Perier     auto high = loop.upperBound();
6368d69237SJean Perier     assert(low && high && "must be a Value");
6468d69237SJean Perier     auto step = loop.step();
6568d69237SJean Perier 
6668d69237SJean Perier     // Initalization block
6768d69237SJean Perier     rewriter.setInsertionPointToEnd(initBlock);
68*a54f4eaeSMogball     auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
69*a54f4eaeSMogball     auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
7068d69237SJean Perier     mlir::Value iters =
71*a54f4eaeSMogball         rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
7268d69237SJean Perier 
7368d69237SJean Perier     if (forceLoopToExecuteOnce) {
74*a54f4eaeSMogball       auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
75*a54f4eaeSMogball       auto cond = rewriter.create<mlir::arith::CmpIOp>(
76*a54f4eaeSMogball           loc, arith::CmpIPredicate::sle, iters, zero);
77*a54f4eaeSMogball       auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
7868d69237SJean Perier       iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters);
7968d69237SJean Perier     }
8068d69237SJean Perier 
8168d69237SJean Perier     llvm::SmallVector<mlir::Value> loopOperands;
8268d69237SJean Perier     loopOperands.push_back(low);
8368d69237SJean Perier     auto operands = loop.getIterOperands();
8468d69237SJean Perier     loopOperands.append(operands.begin(), operands.end());
8568d69237SJean Perier     loopOperands.push_back(iters);
8668d69237SJean Perier 
8768d69237SJean Perier     rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);
8868d69237SJean Perier 
8968d69237SJean Perier     // Last loop block
9068d69237SJean Perier     auto *terminator = lastBlock->getTerminator();
9168d69237SJean Perier     rewriter.setInsertionPointToEnd(lastBlock);
9268d69237SJean Perier     auto iv = conditionalBlock->getArgument(0);
93*a54f4eaeSMogball     mlir::Value steppedIndex =
94*a54f4eaeSMogball         rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
9568d69237SJean Perier     assert(steppedIndex && "must be a Value");
9668d69237SJean Perier     auto lastArg = conditionalBlock->getNumArguments() - 1;
9768d69237SJean Perier     auto itersLeft = conditionalBlock->getArgument(lastArg);
98*a54f4eaeSMogball     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
9968d69237SJean Perier     mlir::Value itersMinusOne =
100*a54f4eaeSMogball         rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one);
10168d69237SJean Perier 
10268d69237SJean Perier     llvm::SmallVector<mlir::Value> loopCarried;
10368d69237SJean Perier     loopCarried.push_back(steppedIndex);
10468d69237SJean Perier     auto begin = loop.finalValue() ? std::next(terminator->operand_begin())
10568d69237SJean Perier                                    : terminator->operand_begin();
10668d69237SJean Perier     loopCarried.append(begin, terminator->operand_end());
10768d69237SJean Perier     loopCarried.push_back(itersMinusOne);
10868d69237SJean Perier     rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
10968d69237SJean Perier     rewriter.eraseOp(terminator);
11068d69237SJean Perier 
11168d69237SJean Perier     // Conditional block
11268d69237SJean Perier     rewriter.setInsertionPointToEnd(conditionalBlock);
113*a54f4eaeSMogball     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
114*a54f4eaeSMogball     auto comparison = rewriter.create<mlir::arith::CmpIOp>(
115*a54f4eaeSMogball         loc, arith::CmpIPredicate::sgt, itersLeft, zero);
11668d69237SJean Perier 
11768d69237SJean Perier     rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
11868d69237SJean Perier                                         llvm::ArrayRef<mlir::Value>(), endBlock,
11968d69237SJean Perier                                         llvm::ArrayRef<mlir::Value>());
12068d69237SJean Perier 
12168d69237SJean Perier     // The result of the loop operation is the values of the condition block
12268d69237SJean Perier     // arguments except the induction variable on the last iteration.
12368d69237SJean Perier     auto args = loop.finalValue()
12468d69237SJean Perier                     ? conditionalBlock->getArguments()
12568d69237SJean Perier                     : conditionalBlock->getArguments().drop_front();
12668d69237SJean Perier     rewriter.replaceOp(loop, args.drop_back());
12768d69237SJean Perier     return success();
12868d69237SJean Perier   }
12968d69237SJean Perier 
13068d69237SJean Perier private:
13168d69237SJean Perier   bool forceLoopToExecuteOnce;
13268d69237SJean Perier };
13368d69237SJean Perier 
13468d69237SJean Perier /// Convert `fir.if` to control-flow
13568d69237SJean Perier class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
13668d69237SJean Perier public:
13768d69237SJean Perier   using OpRewritePattern::OpRewritePattern;
13868d69237SJean Perier 
13968d69237SJean Perier   CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
14068d69237SJean Perier       : mlir::OpRewritePattern<fir::IfOp>(ctx),
14168d69237SJean Perier         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
14268d69237SJean Perier 
14368d69237SJean Perier   mlir::LogicalResult
14468d69237SJean Perier   matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
14568d69237SJean Perier     auto loc = ifOp.getLoc();
14668d69237SJean Perier 
14768d69237SJean Perier     // Split the block containing the 'fir.if' into two parts.  The part before
14868d69237SJean Perier     // will contain the condition, the part after will be the continuation
14968d69237SJean Perier     // point.
15068d69237SJean Perier     auto *condBlock = rewriter.getInsertionBlock();
15168d69237SJean Perier     auto opPosition = rewriter.getInsertionPoint();
15268d69237SJean Perier     auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
15368d69237SJean Perier     mlir::Block *continueBlock;
15468d69237SJean Perier     if (ifOp.getNumResults() == 0) {
15568d69237SJean Perier       continueBlock = remainingOpsBlock;
15668d69237SJean Perier     } else {
15768d69237SJean Perier       continueBlock =
15868d69237SJean Perier           rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
15968d69237SJean Perier       rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
16068d69237SJean Perier     }
16168d69237SJean Perier 
16268d69237SJean Perier     // Move blocks from the "then" region to the region containing 'fir.if',
16368d69237SJean Perier     // place it before the continuation block, and branch to it.
16468d69237SJean Perier     auto &ifOpRegion = ifOp.thenRegion();
16568d69237SJean Perier     auto *ifOpBlock = &ifOpRegion.front();
16668d69237SJean Perier     auto *ifOpTerminator = ifOpRegion.back().getTerminator();
16768d69237SJean Perier     auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
16868d69237SJean Perier     rewriter.setInsertionPointToEnd(&ifOpRegion.back());
16968d69237SJean Perier     rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
17068d69237SJean Perier     rewriter.eraseOp(ifOpTerminator);
17168d69237SJean Perier     rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
17268d69237SJean Perier 
17368d69237SJean Perier     // Move blocks from the "else" region (if present) to the region containing
17468d69237SJean Perier     // 'fir.if', place it before the continuation block and branch to it.  It
17568d69237SJean Perier     // will be placed after the "then" regions.
17668d69237SJean Perier     auto *otherwiseBlock = continueBlock;
17768d69237SJean Perier     auto &otherwiseRegion = ifOp.elseRegion();
17868d69237SJean Perier     if (!otherwiseRegion.empty()) {
17968d69237SJean Perier       otherwiseBlock = &otherwiseRegion.front();
18068d69237SJean Perier       auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
18168d69237SJean Perier       auto otherwiseTermOperands = otherwiseTerm->getOperands();
18268d69237SJean Perier       rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
18368d69237SJean Perier       rewriter.create<mlir::BranchOp>(loc, continueBlock,
18468d69237SJean Perier                                       otherwiseTermOperands);
18568d69237SJean Perier       rewriter.eraseOp(otherwiseTerm);
18668d69237SJean Perier       rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
18768d69237SJean Perier     }
18868d69237SJean Perier 
18968d69237SJean Perier     rewriter.setInsertionPointToEnd(condBlock);
19068d69237SJean Perier     rewriter.create<mlir::CondBranchOp>(
19168d69237SJean Perier         loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
19268d69237SJean Perier         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
19368d69237SJean Perier     rewriter.replaceOp(ifOp, continueBlock->getArguments());
19468d69237SJean Perier     return success();
19568d69237SJean Perier   }
19668d69237SJean Perier 
19768d69237SJean Perier private:
19868d69237SJean Perier   bool forceLoopToExecuteOnce;
19968d69237SJean Perier };
20068d69237SJean Perier 
20168d69237SJean Perier /// Convert `fir.iter_while` to control-flow.
20268d69237SJean Perier class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
20368d69237SJean Perier public:
20468d69237SJean Perier   using OpRewritePattern::OpRewritePattern;
20568d69237SJean Perier 
20668d69237SJean Perier   CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
20768d69237SJean Perier       : mlir::OpRewritePattern<fir::IterWhileOp>(ctx),
20868d69237SJean Perier         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
20968d69237SJean Perier 
21068d69237SJean Perier   mlir::LogicalResult
21168d69237SJean Perier   matchAndRewrite(fir::IterWhileOp whileOp,
21268d69237SJean Perier                   mlir::PatternRewriter &rewriter) const override {
21368d69237SJean Perier     auto loc = whileOp.getLoc();
21468d69237SJean Perier 
21568d69237SJean Perier     // Start by splitting the block containing the 'fir.do_loop' into two parts.
21668d69237SJean Perier     // The part before will get the init code, the part after will be the end
21768d69237SJean Perier     // point.
21868d69237SJean Perier     auto *initBlock = rewriter.getInsertionBlock();
21968d69237SJean Perier     auto initPosition = rewriter.getInsertionPoint();
22068d69237SJean Perier     auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
22168d69237SJean Perier 
22268d69237SJean Perier     // Use the first block of the loop body as the condition block since it is
22368d69237SJean Perier     // the block that has the induction variable and loop-carried values as
22468d69237SJean Perier     // arguments. Split out all operations from the first block into a new
22568d69237SJean Perier     // block. Move all body blocks from the loop body region to the region
22668d69237SJean Perier     // containing the loop.
22768d69237SJean Perier     auto *conditionBlock = &whileOp.region().front();
22868d69237SJean Perier     auto *firstBodyBlock =
22968d69237SJean Perier         rewriter.splitBlock(conditionBlock, conditionBlock->begin());
23068d69237SJean Perier     auto *lastBodyBlock = &whileOp.region().back();
23168d69237SJean Perier     rewriter.inlineRegionBefore(whileOp.region(), endBlock);
23268d69237SJean Perier     auto iv = conditionBlock->getArgument(0);
23368d69237SJean Perier     auto iterateVar = conditionBlock->getArgument(1);
23468d69237SJean Perier 
23568d69237SJean Perier     // Append the induction variable stepping logic to the last body block and
23668d69237SJean Perier     // branch back to the condition block. Loop-carried values are taken from
23768d69237SJean Perier     // operands of the loop terminator.
23868d69237SJean Perier     auto *terminator = lastBodyBlock->getTerminator();
23968d69237SJean Perier     rewriter.setInsertionPointToEnd(lastBodyBlock);
24068d69237SJean Perier     auto step = whileOp.step();
241*a54f4eaeSMogball     mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
24268d69237SJean Perier     assert(stepped && "must be a Value");
24368d69237SJean Perier 
24468d69237SJean Perier     llvm::SmallVector<mlir::Value> loopCarried;
24568d69237SJean Perier     loopCarried.push_back(stepped);
24668d69237SJean Perier     auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
24768d69237SJean Perier                                       : terminator->operand_begin();
24868d69237SJean Perier     loopCarried.append(begin, terminator->operand_end());
24968d69237SJean Perier     rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
25068d69237SJean Perier     rewriter.eraseOp(terminator);
25168d69237SJean Perier 
25268d69237SJean Perier     // Compute loop bounds before branching to the condition.
25368d69237SJean Perier     rewriter.setInsertionPointToEnd(initBlock);
25468d69237SJean Perier     auto lowerBound = whileOp.lowerBound();
25568d69237SJean Perier     auto upperBound = whileOp.upperBound();
25668d69237SJean Perier     assert(lowerBound && upperBound && "must be a Value");
25768d69237SJean Perier 
25868d69237SJean Perier     // The initial values of loop-carried values is obtained from the operands
25968d69237SJean Perier     // of the loop operation.
26068d69237SJean Perier     llvm::SmallVector<mlir::Value> destOperands;
26168d69237SJean Perier     destOperands.push_back(lowerBound);
26268d69237SJean Perier     auto iterOperands = whileOp.getIterOperands();
26368d69237SJean Perier     destOperands.append(iterOperands.begin(), iterOperands.end());
26468d69237SJean Perier     rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);
26568d69237SJean Perier 
26668d69237SJean Perier     // With the body block done, we can fill in the condition block.
26768d69237SJean Perier     rewriter.setInsertionPointToEnd(conditionBlock);
26868d69237SJean Perier     // The comparison depends on the sign of the step value. We fully expect
26968d69237SJean Perier     // this expression to be folded by the optimizer or LLVM. This expression
27068d69237SJean Perier     // is written this way so that `step == 0` always returns `false`.
271*a54f4eaeSMogball     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
272*a54f4eaeSMogball     auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
273*a54f4eaeSMogball         loc, arith::CmpIPredicate::slt, zero, step);
274*a54f4eaeSMogball     auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
275*a54f4eaeSMogball         loc, arith::CmpIPredicate::sle, iv, upperBound);
276*a54f4eaeSMogball     auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
277*a54f4eaeSMogball         loc, arith::CmpIPredicate::slt, step, zero);
278*a54f4eaeSMogball     auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
279*a54f4eaeSMogball         loc, arith::CmpIPredicate::sle, upperBound, iv);
280*a54f4eaeSMogball     auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
281*a54f4eaeSMogball     auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
282*a54f4eaeSMogball     auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
28368d69237SJean Perier     // Remember to AND in the early-exit bool.
284*a54f4eaeSMogball     auto comparison =
285*a54f4eaeSMogball         rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
28668d69237SJean Perier     rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
28768d69237SJean Perier                                         llvm::ArrayRef<mlir::Value>(), endBlock,
28868d69237SJean Perier                                         llvm::ArrayRef<mlir::Value>());
28968d69237SJean Perier     // The result of the loop operation is the values of the condition block
29068d69237SJean Perier     // arguments except the induction variable on the last iteration.
29168d69237SJean Perier     auto args = whileOp.finalValue()
29268d69237SJean Perier                     ? conditionBlock->getArguments()
29368d69237SJean Perier                     : conditionBlock->getArguments().drop_front();
29468d69237SJean Perier     rewriter.replaceOp(whileOp, args);
29568d69237SJean Perier     return success();
29668d69237SJean Perier   }
29768d69237SJean Perier 
29868d69237SJean Perier private:
29968d69237SJean Perier   bool forceLoopToExecuteOnce;
30068d69237SJean Perier };
30168d69237SJean Perier 
30268d69237SJean Perier /// Convert FIR structured control flow ops to CFG ops.
30368d69237SJean Perier class CfgConversion : public CFGConversionBase<CfgConversion> {
30468d69237SJean Perier public:
30568d69237SJean Perier   void runOnFunction() override {
30668d69237SJean Perier     auto *context = &getContext();
30768d69237SJean Perier     mlir::OwningRewritePatternList patterns(context);
30868d69237SJean Perier     patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
30968d69237SJean Perier         context, forceLoopToExecuteOnce);
31068d69237SJean Perier     mlir::ConversionTarget target(*context);
31168d69237SJean Perier     target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
31268d69237SJean Perier                            mlir::StandardOpsDialect>();
31368d69237SJean Perier 
31468d69237SJean Perier     // apply the patterns
31568d69237SJean Perier     target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
31668d69237SJean Perier     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
31768d69237SJean Perier     if (mlir::failed(mlir::applyPartialConversion(getFunction(), target,
31868d69237SJean Perier                                                   std::move(patterns)))) {
31968d69237SJean Perier       mlir::emitError(mlir::UnknownLoc::get(context),
32068d69237SJean Perier                       "error in converting to CFG\n");
32168d69237SJean Perier       signalPassFailure();
32268d69237SJean Perier     }
32368d69237SJean Perier   }
32468d69237SJean Perier };
32568d69237SJean Perier } // namespace
32668d69237SJean Perier 
32768d69237SJean Perier /// Convert FIR's structured control flow ops to CFG ops.  This
32868d69237SJean Perier /// conversion enables the `createLowerToCFGPass` to transform these to CFG
32968d69237SJean Perier /// form.
33068d69237SJean Perier std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() {
33168d69237SJean Perier   return std::make_unique<CfgConversion>();
33268d69237SJean Perier }
333