1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Transforms scf.for operations into equivalent scf.while operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/SCF/Passes.h"
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace llvm;
22 using namespace mlir;
23 using scf::ForOp;
24 using scf::WhileOp;
25 
26 namespace {
27 
28 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
29   using OpRewritePattern<ForOp>::OpRewritePattern;
30 
31   LogicalResult matchAndRewrite(ForOp forOp,
32                                 PatternRewriter &rewriter) const override {
33     // Generate type signature for the loop-carried values. The induction
34     // variable is placed first, followed by the forOp.iterArgs.
35     SmallVector<Type, 8> lcvTypes;
36     lcvTypes.push_back(forOp.getInductionVar().getType());
37     llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes),
38                     [&](auto v) { return v.getType(); });
39 
40     // Build scf.WhileOp
41     SmallVector<Value> initArgs;
42     initArgs.push_back(forOp.lowerBound());
43     llvm::append_range(initArgs, forOp.initArgs());
44     auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
45                                             forOp->getAttrs());
46 
47     // 'before' region contains the loop condition and forwarding of iteration
48     // arguments to the 'after' region.
49     auto *beforeBlock = rewriter.createBlock(
50         &whileOp.before(), whileOp.before().begin(), lcvTypes, {});
51     rewriter.setInsertionPointToStart(&whileOp.before().front());
52     auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt,
53                                          beforeBlock->getArgument(0),
54                                          forOp.upperBound());
55     rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
56                                       beforeBlock->getArguments());
57 
58     // Inline for-loop body into an executeRegion operation in the "after"
59     // region. The return type of the execRegionOp does not contain the
60     // iv - yields in the source for-loop contain only iterArgs.
61     auto *afterBlock = rewriter.createBlock(
62         &whileOp.after(), whileOp.after().begin(), lcvTypes, {});
63 
64     // Add induction variable incrementation
65     rewriter.setInsertionPointToEnd(afterBlock);
66     auto ivIncOp = rewriter.create<AddIOp>(
67         whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());
68 
69     // Rewrite uses of the for-loop block arguments to the new while-loop
70     // "after" arguments
71     for (auto barg : enumerate(forOp.getBody(0)->getArguments()))
72       barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
73 
74     // Inline for-loop body operations into 'after' region.
75     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
76       arg.moveBefore(afterBlock, afterBlock->end());
77 
78     // Add incremented IV to yield operations
79     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
80       SmallVector<Value> yieldOperands = yieldOp.getOperands();
81       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
82       yieldOp->setOperands(yieldOperands);
83     }
84 
85     // We cannot do a direct replacement of the forOp since the while op returns
86     // an extra value (the induction variable escapes the loop through being
87     // carried in the set of iterargs). Instead, rewrite uses of the forOp
88     // results.
89     for (auto arg : llvm::enumerate(forOp.getResults()))
90       arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
91 
92     rewriter.eraseOp(forOp);
93     return success();
94   }
95 };
96 
97 struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
98   void runOnFunction() override {
99     FuncOp funcOp = getFunction();
100     MLIRContext *ctx = funcOp.getContext();
101     RewritePatternSet patterns(ctx);
102     patterns.add<ForLoopLoweringPattern>(ctx);
103     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
104   }
105 };
106 } // namespace
107 
108 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
109   return std::make_unique<ForToWhileLoop>();
110 }
111