1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Transforms scf.for operations into equivalent scf.while operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/SCF/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace llvm;
23 using namespace mlir;
24 using scf::ForOp;
25 using scf::WhileOp;
26 
27 namespace {
28 
29 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
30   using OpRewritePattern<ForOp>::OpRewritePattern;
31 
32   LogicalResult matchAndRewrite(ForOp forOp,
33                                 PatternRewriter &rewriter) const override {
34     // Generate type signature for the loop-carried values. The induction
35     // variable is placed first, followed by the forOp.iterArgs.
36     SmallVector<Type, 8> lcvTypes;
37     lcvTypes.push_back(forOp.getInductionVar().getType());
38     llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes),
39                     [&](auto v) { return v.getType(); });
40 
41     // Build scf.WhileOp
42     SmallVector<Value> initArgs;
43     initArgs.push_back(forOp.lowerBound());
44     llvm::append_range(initArgs, forOp.initArgs());
45     auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
46                                             forOp->getAttrs());
47 
48     // 'before' region contains the loop condition and forwarding of iteration
49     // arguments to the 'after' region.
50     auto *beforeBlock = rewriter.createBlock(
51         &whileOp.before(), whileOp.before().begin(), lcvTypes, {});
52     rewriter.setInsertionPointToStart(&whileOp.before().front());
53     auto cmpOp = rewriter.create<arith::CmpIOp>(
54         whileOp.getLoc(), arith::CmpIPredicate::slt,
55         beforeBlock->getArgument(0), forOp.upperBound());
56     rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
57                                       beforeBlock->getArguments());
58 
59     // Inline for-loop body into an executeRegion operation in the "after"
60     // region. The return type of the execRegionOp does not contain the
61     // iv - yields in the source for-loop contain only iterArgs.
62     auto *afterBlock = rewriter.createBlock(
63         &whileOp.after(), whileOp.after().begin(), lcvTypes, {});
64 
65     // Add induction variable incrementation
66     rewriter.setInsertionPointToEnd(afterBlock);
67     auto ivIncOp = rewriter.create<arith::AddIOp>(
68         whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());
69 
70     // Rewrite uses of the for-loop block arguments to the new while-loop
71     // "after" arguments
72     for (auto barg : enumerate(forOp.getBody(0)->getArguments()))
73       barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
74 
75     // Inline for-loop body operations into 'after' region.
76     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
77       arg.moveBefore(afterBlock, afterBlock->end());
78 
79     // Add incremented IV to yield operations
80     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
81       SmallVector<Value> yieldOperands = yieldOp.getOperands();
82       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
83       yieldOp->setOperands(yieldOperands);
84     }
85 
86     // We cannot do a direct replacement of the forOp since the while op returns
87     // an extra value (the induction variable escapes the loop through being
88     // carried in the set of iterargs). Instead, rewrite uses of the forOp
89     // results.
90     for (auto arg : llvm::enumerate(forOp.getResults()))
91       arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
92 
93     rewriter.eraseOp(forOp);
94     return success();
95   }
96 };
97 
98 struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
99   void runOnFunction() override {
100     FuncOp funcOp = getFunction();
101     MLIRContext *ctx = funcOp.getContext();
102     RewritePatternSet patterns(ctx);
103     patterns.add<ForLoopLoweringPattern>(ctx);
104     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
105   }
106 };
107 } // namespace
108 
109 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
110   return std::make_unique<ForToWhileLoop>();
111 }
112