1032cb165SMorten Borup Petersen //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2032cb165SMorten Borup Petersen //
3032cb165SMorten Borup Petersen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4032cb165SMorten Borup Petersen // See https://llvm.org/LICENSE.txt for license information.
5032cb165SMorten Borup Petersen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6032cb165SMorten Borup Petersen //
7032cb165SMorten Borup Petersen //===----------------------------------------------------------------------===//
8032cb165SMorten Borup Petersen //
// Transforms scf.for operations into equivalent scf.while operations.
10032cb165SMorten Borup Petersen //
11032cb165SMorten Borup Petersen //===----------------------------------------------------------------------===//
12032cb165SMorten Borup Petersen 
13032cb165SMorten Borup Petersen #include "PassDetail.h"
14*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15032cb165SMorten Borup Petersen #include "mlir/Dialect/SCF/Passes.h"
16032cb165SMorten Borup Petersen #include "mlir/Dialect/SCF/SCF.h"
17032cb165SMorten Borup Petersen #include "mlir/Dialect/SCF/Transforms.h"
18032cb165SMorten Borup Petersen #include "mlir/Dialect/StandardOps/IR/Ops.h"
19032cb165SMorten Borup Petersen #include "mlir/IR/PatternMatch.h"
20032cb165SMorten Borup Petersen #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21032cb165SMorten Borup Petersen 
22032cb165SMorten Borup Petersen using namespace llvm;
23032cb165SMorten Borup Petersen using namespace mlir;
24032cb165SMorten Borup Petersen using scf::ForOp;
25032cb165SMorten Borup Petersen using scf::WhileOp;
26032cb165SMorten Borup Petersen 
27032cb165SMorten Borup Petersen namespace {
28032cb165SMorten Borup Petersen 
29032cb165SMorten Borup Petersen struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
30032cb165SMorten Borup Petersen   using OpRewritePattern<ForOp>::OpRewritePattern;
31032cb165SMorten Borup Petersen 
32032cb165SMorten Borup Petersen   LogicalResult matchAndRewrite(ForOp forOp,
33032cb165SMorten Borup Petersen                                 PatternRewriter &rewriter) const override {
34032cb165SMorten Borup Petersen     // Generate type signature for the loop-carried values. The induction
35032cb165SMorten Borup Petersen     // variable is placed first, followed by the forOp.iterArgs.
36032cb165SMorten Borup Petersen     SmallVector<Type, 8> lcvTypes;
37032cb165SMorten Borup Petersen     lcvTypes.push_back(forOp.getInductionVar().getType());
38032cb165SMorten Borup Petersen     llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes),
39032cb165SMorten Borup Petersen                     [&](auto v) { return v.getType(); });
40032cb165SMorten Borup Petersen 
41032cb165SMorten Borup Petersen     // Build scf.WhileOp
42032cb165SMorten Borup Petersen     SmallVector<Value> initArgs;
43032cb165SMorten Borup Petersen     initArgs.push_back(forOp.lowerBound());
44032cb165SMorten Borup Petersen     llvm::append_range(initArgs, forOp.initArgs());
45032cb165SMorten Borup Petersen     auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
46032cb165SMorten Borup Petersen                                             forOp->getAttrs());
47032cb165SMorten Borup Petersen 
48032cb165SMorten Borup Petersen     // 'before' region contains the loop condition and forwarding of iteration
49032cb165SMorten Borup Petersen     // arguments to the 'after' region.
50032cb165SMorten Borup Petersen     auto *beforeBlock = rewriter.createBlock(
51032cb165SMorten Borup Petersen         &whileOp.before(), whileOp.before().begin(), lcvTypes, {});
52032cb165SMorten Borup Petersen     rewriter.setInsertionPointToStart(&whileOp.before().front());
53*a54f4eaeSMogball     auto cmpOp = rewriter.create<arith::CmpIOp>(
54*a54f4eaeSMogball         whileOp.getLoc(), arith::CmpIPredicate::slt,
55*a54f4eaeSMogball         beforeBlock->getArgument(0), forOp.upperBound());
56032cb165SMorten Borup Petersen     rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
57032cb165SMorten Borup Petersen                                       beforeBlock->getArguments());
58032cb165SMorten Borup Petersen 
59032cb165SMorten Borup Petersen     // Inline for-loop body into an executeRegion operation in the "after"
60032cb165SMorten Borup Petersen     // region. The return type of the execRegionOp does not contain the
61032cb165SMorten Borup Petersen     // iv - yields in the source for-loop contain only iterArgs.
62032cb165SMorten Borup Petersen     auto *afterBlock = rewriter.createBlock(
63032cb165SMorten Borup Petersen         &whileOp.after(), whileOp.after().begin(), lcvTypes, {});
64032cb165SMorten Borup Petersen 
65032cb165SMorten Borup Petersen     // Add induction variable incrementation
66032cb165SMorten Borup Petersen     rewriter.setInsertionPointToEnd(afterBlock);
67*a54f4eaeSMogball     auto ivIncOp = rewriter.create<arith::AddIOp>(
68032cb165SMorten Borup Petersen         whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());
69032cb165SMorten Borup Petersen 
70032cb165SMorten Borup Petersen     // Rewrite uses of the for-loop block arguments to the new while-loop
71032cb165SMorten Borup Petersen     // "after" arguments
72032cb165SMorten Borup Petersen     for (auto barg : enumerate(forOp.getBody(0)->getArguments()))
73032cb165SMorten Borup Petersen       barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
74032cb165SMorten Borup Petersen 
75032cb165SMorten Borup Petersen     // Inline for-loop body operations into 'after' region.
76032cb165SMorten Borup Petersen     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
77032cb165SMorten Borup Petersen       arg.moveBefore(afterBlock, afterBlock->end());
78032cb165SMorten Borup Petersen 
79032cb165SMorten Borup Petersen     // Add incremented IV to yield operations
80032cb165SMorten Borup Petersen     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
81032cb165SMorten Borup Petersen       SmallVector<Value> yieldOperands = yieldOp.getOperands();
82032cb165SMorten Borup Petersen       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
83032cb165SMorten Borup Petersen       yieldOp->setOperands(yieldOperands);
84032cb165SMorten Borup Petersen     }
85032cb165SMorten Borup Petersen 
86032cb165SMorten Borup Petersen     // We cannot do a direct replacement of the forOp since the while op returns
87032cb165SMorten Borup Petersen     // an extra value (the induction variable escapes the loop through being
88032cb165SMorten Borup Petersen     // carried in the set of iterargs). Instead, rewrite uses of the forOp
89032cb165SMorten Borup Petersen     // results.
90032cb165SMorten Borup Petersen     for (auto arg : llvm::enumerate(forOp.getResults()))
91032cb165SMorten Borup Petersen       arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
92032cb165SMorten Borup Petersen 
93032cb165SMorten Borup Petersen     rewriter.eraseOp(forOp);
94032cb165SMorten Borup Petersen     return success();
95032cb165SMorten Borup Petersen   }
96032cb165SMorten Borup Petersen };
97032cb165SMorten Borup Petersen 
98032cb165SMorten Borup Petersen struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
99032cb165SMorten Borup Petersen   void runOnFunction() override {
100032cb165SMorten Borup Petersen     FuncOp funcOp = getFunction();
101032cb165SMorten Borup Petersen     MLIRContext *ctx = funcOp.getContext();
102032cb165SMorten Borup Petersen     RewritePatternSet patterns(ctx);
103032cb165SMorten Borup Petersen     patterns.add<ForLoopLoweringPattern>(ctx);
104032cb165SMorten Borup Petersen     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
105032cb165SMorten Borup Petersen   }
106032cb165SMorten Borup Petersen };
107032cb165SMorten Borup Petersen } // namespace
108032cb165SMorten Borup Petersen 
109032cb165SMorten Borup Petersen std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
110032cb165SMorten Borup Petersen   return std::make_unique<ForToWhileLoop>();
111032cb165SMorten Borup Petersen }
112