//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms scf.for operations into scf.while operations.
//
//===----------------------------------------------------------------------===//
12032cb165SMorten Borup Petersen
13032cb165SMorten Borup Petersen #include "PassDetail.h"
14a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
168b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Passes.h"
178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18032cb165SMorten Borup Petersen #include "mlir/IR/PatternMatch.h"
19032cb165SMorten Borup Petersen #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20032cb165SMorten Borup Petersen
21032cb165SMorten Borup Petersen using namespace llvm;
22032cb165SMorten Borup Petersen using namespace mlir;
23032cb165SMorten Borup Petersen using scf::ForOp;
24032cb165SMorten Borup Petersen using scf::WhileOp;
25032cb165SMorten Borup Petersen
26032cb165SMorten Borup Petersen namespace {
27032cb165SMorten Borup Petersen
28032cb165SMorten Borup Petersen struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
29032cb165SMorten Borup Petersen using OpRewritePattern<ForOp>::OpRewritePattern;
30032cb165SMorten Borup Petersen
matchAndRewrite__anon26886cca0111::ForLoopLoweringPattern31032cb165SMorten Borup Petersen LogicalResult matchAndRewrite(ForOp forOp,
32032cb165SMorten Borup Petersen PatternRewriter &rewriter) const override {
33032cb165SMorten Borup Petersen // Generate type signature for the loop-carried values. The induction
34032cb165SMorten Borup Petersen // variable is placed first, followed by the forOp.iterArgs.
35e084679fSRiver Riddle SmallVector<Type> lcvTypes;
36e084679fSRiver Riddle SmallVector<Location> lcvLocs;
37032cb165SMorten Borup Petersen lcvTypes.push_back(forOp.getInductionVar().getType());
38e084679fSRiver Riddle lcvLocs.push_back(forOp.getInductionVar().getLoc());
39e084679fSRiver Riddle for (Value value : forOp.getInitArgs()) {
40e084679fSRiver Riddle lcvTypes.push_back(value.getType());
41e084679fSRiver Riddle lcvLocs.push_back(value.getLoc());
42e084679fSRiver Riddle }
43032cb165SMorten Borup Petersen
44032cb165SMorten Borup Petersen // Build scf.WhileOp
45032cb165SMorten Borup Petersen SmallVector<Value> initArgs;
46c0342a2dSJacques Pienaar initArgs.push_back(forOp.getLowerBound());
47c0342a2dSJacques Pienaar llvm::append_range(initArgs, forOp.getInitArgs());
48032cb165SMorten Borup Petersen auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
49032cb165SMorten Borup Petersen forOp->getAttrs());
50032cb165SMorten Borup Petersen
51032cb165SMorten Borup Petersen // 'before' region contains the loop condition and forwarding of iteration
52032cb165SMorten Borup Petersen // arguments to the 'after' region.
53032cb165SMorten Borup Petersen auto *beforeBlock = rewriter.createBlock(
54e084679fSRiver Riddle &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
55c0342a2dSJacques Pienaar rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
56a54f4eaeSMogball auto cmpOp = rewriter.create<arith::CmpIOp>(
57a54f4eaeSMogball whileOp.getLoc(), arith::CmpIPredicate::slt,
58c0342a2dSJacques Pienaar beforeBlock->getArgument(0), forOp.getUpperBound());
59032cb165SMorten Borup Petersen rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
60032cb165SMorten Borup Petersen beforeBlock->getArguments());
61032cb165SMorten Borup Petersen
62032cb165SMorten Borup Petersen // Inline for-loop body into an executeRegion operation in the "after"
63032cb165SMorten Borup Petersen // region. The return type of the execRegionOp does not contain the
64032cb165SMorten Borup Petersen // iv - yields in the source for-loop contain only iterArgs.
65032cb165SMorten Borup Petersen auto *afterBlock = rewriter.createBlock(
66e084679fSRiver Riddle &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
67032cb165SMorten Borup Petersen
68032cb165SMorten Borup Petersen // Add induction variable incrementation
69032cb165SMorten Borup Petersen rewriter.setInsertionPointToEnd(afterBlock);
70a54f4eaeSMogball auto ivIncOp = rewriter.create<arith::AddIOp>(
71c0342a2dSJacques Pienaar whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
72032cb165SMorten Borup Petersen
73032cb165SMorten Borup Petersen // Rewrite uses of the for-loop block arguments to the new while-loop
74032cb165SMorten Borup Petersen // "after" arguments
75e4853be2SMehdi Amini for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
76032cb165SMorten Borup Petersen barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
77032cb165SMorten Borup Petersen
78032cb165SMorten Borup Petersen // Inline for-loop body operations into 'after' region.
79032cb165SMorten Borup Petersen for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
80032cb165SMorten Borup Petersen arg.moveBefore(afterBlock, afterBlock->end());
81032cb165SMorten Borup Petersen
82032cb165SMorten Borup Petersen // Add incremented IV to yield operations
83032cb165SMorten Borup Petersen for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
84032cb165SMorten Borup Petersen SmallVector<Value> yieldOperands = yieldOp.getOperands();
85032cb165SMorten Borup Petersen yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
86032cb165SMorten Borup Petersen yieldOp->setOperands(yieldOperands);
87032cb165SMorten Borup Petersen }
88032cb165SMorten Borup Petersen
89032cb165SMorten Borup Petersen // We cannot do a direct replacement of the forOp since the while op returns
90032cb165SMorten Borup Petersen // an extra value (the induction variable escapes the loop through being
91032cb165SMorten Borup Petersen // carried in the set of iterargs). Instead, rewrite uses of the forOp
92032cb165SMorten Borup Petersen // results.
93e4853be2SMehdi Amini for (const auto &arg : llvm::enumerate(forOp.getResults()))
94032cb165SMorten Borup Petersen arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
95032cb165SMorten Borup Petersen
96032cb165SMorten Borup Petersen rewriter.eraseOp(forOp);
97032cb165SMorten Borup Petersen return success();
98032cb165SMorten Borup Petersen }
99032cb165SMorten Borup Petersen };
100032cb165SMorten Borup Petersen
101032cb165SMorten Borup Petersen struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
runOnOperation__anon26886cca0111::ForToWhileLoop10241574554SRiver Riddle void runOnOperation() override {
103*54998986SStella Laurenzo auto *parentOp = getOperation();
104*54998986SStella Laurenzo MLIRContext *ctx = parentOp->getContext();
105032cb165SMorten Borup Petersen RewritePatternSet patterns(ctx);
106032cb165SMorten Borup Petersen patterns.add<ForLoopLoweringPattern>(ctx);
107*54998986SStella Laurenzo (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
108032cb165SMorten Borup Petersen }
109032cb165SMorten Borup Petersen };
110032cb165SMorten Borup Petersen } // namespace
111032cb165SMorten Borup Petersen
createForToWhileLoopPass()112032cb165SMorten Borup Petersen std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
113032cb165SMorten Borup Petersen return std::make_unique<ForToWhileLoop>();
114032cb165SMorten Borup Petersen }
115