1032cb165SMorten Borup Petersen //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2032cb165SMorten Borup Petersen //
3032cb165SMorten Borup Petersen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4032cb165SMorten Borup Petersen // See https://llvm.org/LICENSE.txt for license information.
5032cb165SMorten Borup Petersen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6032cb165SMorten Borup Petersen //
7032cb165SMorten Borup Petersen //===----------------------------------------------------------------------===//
8032cb165SMorten Borup Petersen //
// Transforms scf.for operations into scf.while operations.
10032cb165SMorten Borup Petersen //
11032cb165SMorten Borup Petersen //===----------------------------------------------------------------------===//
12032cb165SMorten Borup Petersen 
13032cb165SMorten Borup Petersen #include "PassDetail.h"
14a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Passes.h"
178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18032cb165SMorten Borup Petersen #include "mlir/IR/PatternMatch.h"
19032cb165SMorten Borup Petersen #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20032cb165SMorten Borup Petersen 
21032cb165SMorten Borup Petersen using namespace llvm;
22032cb165SMorten Borup Petersen using namespace mlir;
23032cb165SMorten Borup Petersen using scf::ForOp;
24032cb165SMorten Borup Petersen using scf::WhileOp;
25032cb165SMorten Borup Petersen 
26032cb165SMorten Borup Petersen namespace {
27032cb165SMorten Borup Petersen 
28032cb165SMorten Borup Petersen struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
29032cb165SMorten Borup Petersen   using OpRewritePattern<ForOp>::OpRewritePattern;
30032cb165SMorten Borup Petersen 
matchAndRewrite__anon26886cca0111::ForLoopLoweringPattern31032cb165SMorten Borup Petersen   LogicalResult matchAndRewrite(ForOp forOp,
32032cb165SMorten Borup Petersen                                 PatternRewriter &rewriter) const override {
33032cb165SMorten Borup Petersen     // Generate type signature for the loop-carried values. The induction
34032cb165SMorten Borup Petersen     // variable is placed first, followed by the forOp.iterArgs.
35e084679fSRiver Riddle     SmallVector<Type> lcvTypes;
36e084679fSRiver Riddle     SmallVector<Location> lcvLocs;
37032cb165SMorten Borup Petersen     lcvTypes.push_back(forOp.getInductionVar().getType());
38e084679fSRiver Riddle     lcvLocs.push_back(forOp.getInductionVar().getLoc());
39e084679fSRiver Riddle     for (Value value : forOp.getInitArgs()) {
40e084679fSRiver Riddle       lcvTypes.push_back(value.getType());
41e084679fSRiver Riddle       lcvLocs.push_back(value.getLoc());
42e084679fSRiver Riddle     }
43032cb165SMorten Borup Petersen 
44032cb165SMorten Borup Petersen     // Build scf.WhileOp
45032cb165SMorten Borup Petersen     SmallVector<Value> initArgs;
46c0342a2dSJacques Pienaar     initArgs.push_back(forOp.getLowerBound());
47c0342a2dSJacques Pienaar     llvm::append_range(initArgs, forOp.getInitArgs());
48032cb165SMorten Borup Petersen     auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
49032cb165SMorten Borup Petersen                                             forOp->getAttrs());
50032cb165SMorten Borup Petersen 
51032cb165SMorten Borup Petersen     // 'before' region contains the loop condition and forwarding of iteration
52032cb165SMorten Borup Petersen     // arguments to the 'after' region.
53032cb165SMorten Borup Petersen     auto *beforeBlock = rewriter.createBlock(
54e084679fSRiver Riddle         &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
55c0342a2dSJacques Pienaar     rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
56a54f4eaeSMogball     auto cmpOp = rewriter.create<arith::CmpIOp>(
57a54f4eaeSMogball         whileOp.getLoc(), arith::CmpIPredicate::slt,
58c0342a2dSJacques Pienaar         beforeBlock->getArgument(0), forOp.getUpperBound());
59032cb165SMorten Borup Petersen     rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
60032cb165SMorten Borup Petersen                                       beforeBlock->getArguments());
61032cb165SMorten Borup Petersen 
62032cb165SMorten Borup Petersen     // Inline for-loop body into an executeRegion operation in the "after"
63032cb165SMorten Borup Petersen     // region. The return type of the execRegionOp does not contain the
64032cb165SMorten Borup Petersen     // iv - yields in the source for-loop contain only iterArgs.
65032cb165SMorten Borup Petersen     auto *afterBlock = rewriter.createBlock(
66e084679fSRiver Riddle         &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
67032cb165SMorten Borup Petersen 
68032cb165SMorten Borup Petersen     // Add induction variable incrementation
69032cb165SMorten Borup Petersen     rewriter.setInsertionPointToEnd(afterBlock);
70a54f4eaeSMogball     auto ivIncOp = rewriter.create<arith::AddIOp>(
71c0342a2dSJacques Pienaar         whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
72032cb165SMorten Borup Petersen 
73032cb165SMorten Borup Petersen     // Rewrite uses of the for-loop block arguments to the new while-loop
74032cb165SMorten Borup Petersen     // "after" arguments
75e4853be2SMehdi Amini     for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
76032cb165SMorten Borup Petersen       barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
77032cb165SMorten Borup Petersen 
78032cb165SMorten Borup Petersen     // Inline for-loop body operations into 'after' region.
79032cb165SMorten Borup Petersen     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
80032cb165SMorten Borup Petersen       arg.moveBefore(afterBlock, afterBlock->end());
81032cb165SMorten Borup Petersen 
82032cb165SMorten Borup Petersen     // Add incremented IV to yield operations
83032cb165SMorten Borup Petersen     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
84032cb165SMorten Borup Petersen       SmallVector<Value> yieldOperands = yieldOp.getOperands();
85032cb165SMorten Borup Petersen       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
86032cb165SMorten Borup Petersen       yieldOp->setOperands(yieldOperands);
87032cb165SMorten Borup Petersen     }
88032cb165SMorten Borup Petersen 
89032cb165SMorten Borup Petersen     // We cannot do a direct replacement of the forOp since the while op returns
90032cb165SMorten Borup Petersen     // an extra value (the induction variable escapes the loop through being
91032cb165SMorten Borup Petersen     // carried in the set of iterargs). Instead, rewrite uses of the forOp
92032cb165SMorten Borup Petersen     // results.
93e4853be2SMehdi Amini     for (const auto &arg : llvm::enumerate(forOp.getResults()))
94032cb165SMorten Borup Petersen       arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
95032cb165SMorten Borup Petersen 
96032cb165SMorten Borup Petersen     rewriter.eraseOp(forOp);
97032cb165SMorten Borup Petersen     return success();
98032cb165SMorten Borup Petersen   }
99032cb165SMorten Borup Petersen };
100032cb165SMorten Borup Petersen 
101032cb165SMorten Borup Petersen struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
runOnOperation__anon26886cca0111::ForToWhileLoop10241574554SRiver Riddle   void runOnOperation() override {
103*54998986SStella Laurenzo     auto *parentOp = getOperation();
104*54998986SStella Laurenzo     MLIRContext *ctx = parentOp->getContext();
105032cb165SMorten Borup Petersen     RewritePatternSet patterns(ctx);
106032cb165SMorten Borup Petersen     patterns.add<ForLoopLoweringPattern>(ctx);
107*54998986SStella Laurenzo     (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
108032cb165SMorten Borup Petersen   }
109032cb165SMorten Borup Petersen };
110032cb165SMorten Borup Petersen } // namespace
111032cb165SMorten Borup Petersen 
createForToWhileLoopPass()112032cb165SMorten Borup Petersen std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
113032cb165SMorten Borup Petersen   return std::make_unique<ForToWhileLoop>();
114032cb165SMorten Borup Petersen }
115