1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Transforms scf.for operations into equivalent scf.while operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/SCF/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace llvm;
23 using namespace mlir;
24 using scf::ForOp;
25 using scf::WhileOp;
26 
27 namespace {
28 
29 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
30   using OpRewritePattern<ForOp>::OpRewritePattern;
31 
32   LogicalResult matchAndRewrite(ForOp forOp,
33                                 PatternRewriter &rewriter) const override {
34     // Generate type signature for the loop-carried values. The induction
35     // variable is placed first, followed by the forOp.iterArgs.
36     SmallVector<Type> lcvTypes;
37     SmallVector<Location> lcvLocs;
38     lcvTypes.push_back(forOp.getInductionVar().getType());
39     lcvLocs.push_back(forOp.getInductionVar().getLoc());
40     for (Value value : forOp.getInitArgs()) {
41       lcvTypes.push_back(value.getType());
42       lcvLocs.push_back(value.getLoc());
43     }
44 
45     // Build scf.WhileOp
46     SmallVector<Value> initArgs;
47     initArgs.push_back(forOp.getLowerBound());
48     llvm::append_range(initArgs, forOp.getInitArgs());
49     auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
50                                             forOp->getAttrs());
51 
52     // 'before' region contains the loop condition and forwarding of iteration
53     // arguments to the 'after' region.
54     auto *beforeBlock = rewriter.createBlock(
55         &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
56     rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
57     auto cmpOp = rewriter.create<arith::CmpIOp>(
58         whileOp.getLoc(), arith::CmpIPredicate::slt,
59         beforeBlock->getArgument(0), forOp.getUpperBound());
60     rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
61                                       beforeBlock->getArguments());
62 
63     // Inline for-loop body into an executeRegion operation in the "after"
64     // region. The return type of the execRegionOp does not contain the
65     // iv - yields in the source for-loop contain only iterArgs.
66     auto *afterBlock = rewriter.createBlock(
67         &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
68 
69     // Add induction variable incrementation
70     rewriter.setInsertionPointToEnd(afterBlock);
71     auto ivIncOp = rewriter.create<arith::AddIOp>(
72         whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
73 
74     // Rewrite uses of the for-loop block arguments to the new while-loop
75     // "after" arguments
76     for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
77       barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
78 
79     // Inline for-loop body operations into 'after' region.
80     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
81       arg.moveBefore(afterBlock, afterBlock->end());
82 
83     // Add incremented IV to yield operations
84     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
85       SmallVector<Value> yieldOperands = yieldOp.getOperands();
86       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
87       yieldOp->setOperands(yieldOperands);
88     }
89 
90     // We cannot do a direct replacement of the forOp since the while op returns
91     // an extra value (the induction variable escapes the loop through being
92     // carried in the set of iterargs). Instead, rewrite uses of the forOp
93     // results.
94     for (const auto &arg : llvm::enumerate(forOp.getResults()))
95       arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
96 
97     rewriter.eraseOp(forOp);
98     return success();
99   }
100 };
101 
102 struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
103   void runOnOperation() override {
104     FuncOp funcOp = getOperation();
105     MLIRContext *ctx = funcOp.getContext();
106     RewritePatternSet patterns(ctx);
107     patterns.add<ForLoopLoweringPattern>(ctx);
108     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
109   }
110 };
111 } // namespace
112 
113 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
114   return std::make_unique<ForToWhileLoop>();
115 }
116