1 //===-- RewriteLoop.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Transforms/Passes.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "llvm/Support/CommandLine.h"
19 
20 using namespace fir;
21 
22 namespace {
23 
24 // Conversion of fir control ops to more primitive control-flow.
25 //
26 // FIR loops that cannot be converted to the affine dialect will remain as
27 // `fir.do_loop` operations.  These can be converted to control-flow operations.
28 
29 /// Convert `fir.do_loop` to CFG
30 class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
31 public:
32   using OpRewritePattern::OpRewritePattern;
33 
34   CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
35       : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
36         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
37 
38   mlir::LogicalResult
39   matchAndRewrite(DoLoopOp loop,
40                   mlir::PatternRewriter &rewriter) const override {
41     auto loc = loop.getLoc();
42 
43     // Create the start and end blocks that will wrap the DoLoopOp with an
44     // initalizer and an end point
45     auto *initBlock = rewriter.getInsertionBlock();
46     auto initPos = rewriter.getInsertionPoint();
47     auto *endBlock = rewriter.splitBlock(initBlock, initPos);
48 
49     // Split the first DoLoopOp block in two parts. The part before will be the
50     // conditional block since it already has the induction variable and
51     // loop-carried values as arguments.
52     auto *conditionalBlock = &loop.getRegion().front();
53     conditionalBlock->addArgument(rewriter.getIndexType(), loc);
54     auto *firstBlock =
55         rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
56     auto *lastBlock = &loop.getRegion().back();
57 
58     // Move the blocks from the DoLoopOp between initBlock and endBlock
59     rewriter.inlineRegionBefore(loop.getRegion(), endBlock);
60 
61     // Get loop values from the DoLoopOp
62     auto low = loop.getLowerBound();
63     auto high = loop.getUpperBound();
64     assert(low && high && "must be a Value");
65     auto step = loop.getStep();
66 
67     // Initalization block
68     rewriter.setInsertionPointToEnd(initBlock);
69     auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
70     auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
71     mlir::Value iters =
72         rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
73 
74     if (forceLoopToExecuteOnce) {
75       auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
76       auto cond = rewriter.create<mlir::arith::CmpIOp>(
77           loc, arith::CmpIPredicate::sle, iters, zero);
78       auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
79       iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters);
80     }
81 
82     llvm::SmallVector<mlir::Value> loopOperands;
83     loopOperands.push_back(low);
84     auto operands = loop.getIterOperands();
85     loopOperands.append(operands.begin(), operands.end());
86     loopOperands.push_back(iters);
87 
88     rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands);
89 
90     // Last loop block
91     auto *terminator = lastBlock->getTerminator();
92     rewriter.setInsertionPointToEnd(lastBlock);
93     auto iv = conditionalBlock->getArgument(0);
94     mlir::Value steppedIndex =
95         rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
96     assert(steppedIndex && "must be a Value");
97     auto lastArg = conditionalBlock->getNumArguments() - 1;
98     auto itersLeft = conditionalBlock->getArgument(lastArg);
99     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
100     mlir::Value itersMinusOne =
101         rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one);
102 
103     llvm::SmallVector<mlir::Value> loopCarried;
104     loopCarried.push_back(steppedIndex);
105     auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin())
106                                       : terminator->operand_begin();
107     loopCarried.append(begin, terminator->operand_end());
108     loopCarried.push_back(itersMinusOne);
109     rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried);
110     rewriter.eraseOp(terminator);
111 
112     // Conditional block
113     rewriter.setInsertionPointToEnd(conditionalBlock);
114     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
115     auto comparison = rewriter.create<mlir::arith::CmpIOp>(
116         loc, arith::CmpIPredicate::sgt, itersLeft, zero);
117 
118     rewriter.create<mlir::cf::CondBranchOp>(
119         loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
120         llvm::ArrayRef<mlir::Value>());
121 
122     // The result of the loop operation is the values of the condition block
123     // arguments except the induction variable on the last iteration.
124     auto args = loop.getFinalValue()
125                     ? conditionalBlock->getArguments()
126                     : conditionalBlock->getArguments().drop_front();
127     rewriter.replaceOp(loop, args.drop_back());
128     return success();
129   }
130 
131 private:
132   bool forceLoopToExecuteOnce;
133 };
134 
135 /// Convert `fir.if` to control-flow
136 class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
137 public:
138   using OpRewritePattern::OpRewritePattern;
139 
140   CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
141       : mlir::OpRewritePattern<fir::IfOp>(ctx) {}
142 
143   mlir::LogicalResult
144   matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
145     auto loc = ifOp.getLoc();
146 
147     // Split the block containing the 'fir.if' into two parts.  The part before
148     // will contain the condition, the part after will be the continuation
149     // point.
150     auto *condBlock = rewriter.getInsertionBlock();
151     auto opPosition = rewriter.getInsertionPoint();
152     auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
153     mlir::Block *continueBlock;
154     if (ifOp.getNumResults() == 0) {
155       continueBlock = remainingOpsBlock;
156     } else {
157       continueBlock =
158           rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
159       rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock);
160     }
161 
162     // Move blocks from the "then" region to the region containing 'fir.if',
163     // place it before the continuation block, and branch to it.
164     auto &ifOpRegion = ifOp.getThenRegion();
165     auto *ifOpBlock = &ifOpRegion.front();
166     auto *ifOpTerminator = ifOpRegion.back().getTerminator();
167     auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
168     rewriter.setInsertionPointToEnd(&ifOpRegion.back());
169     rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
170                                         ifOpTerminatorOperands);
171     rewriter.eraseOp(ifOpTerminator);
172     rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
173 
174     // Move blocks from the "else" region (if present) to the region containing
175     // 'fir.if', place it before the continuation block and branch to it.  It
176     // will be placed after the "then" regions.
177     auto *otherwiseBlock = continueBlock;
178     auto &otherwiseRegion = ifOp.getElseRegion();
179     if (!otherwiseRegion.empty()) {
180       otherwiseBlock = &otherwiseRegion.front();
181       auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
182       auto otherwiseTermOperands = otherwiseTerm->getOperands();
183       rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
184       rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
185                                           otherwiseTermOperands);
186       rewriter.eraseOp(otherwiseTerm);
187       rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
188     }
189 
190     rewriter.setInsertionPointToEnd(condBlock);
191     rewriter.create<mlir::cf::CondBranchOp>(
192         loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
193         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
194     rewriter.replaceOp(ifOp, continueBlock->getArguments());
195     return success();
196   }
197 };
198 
199 /// Convert `fir.iter_while` to control-flow.
200 class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
201 public:
202   using OpRewritePattern::OpRewritePattern;
203 
204   CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
205       : mlir::OpRewritePattern<fir::IterWhileOp>(ctx) {}
206 
207   mlir::LogicalResult
208   matchAndRewrite(fir::IterWhileOp whileOp,
209                   mlir::PatternRewriter &rewriter) const override {
210     auto loc = whileOp.getLoc();
211 
212     // Start by splitting the block containing the 'fir.do_loop' into two parts.
213     // The part before will get the init code, the part after will be the end
214     // point.
215     auto *initBlock = rewriter.getInsertionBlock();
216     auto initPosition = rewriter.getInsertionPoint();
217     auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
218 
219     // Use the first block of the loop body as the condition block since it is
220     // the block that has the induction variable and loop-carried values as
221     // arguments. Split out all operations from the first block into a new
222     // block. Move all body blocks from the loop body region to the region
223     // containing the loop.
224     auto *conditionBlock = &whileOp.getRegion().front();
225     auto *firstBodyBlock =
226         rewriter.splitBlock(conditionBlock, conditionBlock->begin());
227     auto *lastBodyBlock = &whileOp.getRegion().back();
228     rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock);
229     auto iv = conditionBlock->getArgument(0);
230     auto iterateVar = conditionBlock->getArgument(1);
231 
232     // Append the induction variable stepping logic to the last body block and
233     // branch back to the condition block. Loop-carried values are taken from
234     // operands of the loop terminator.
235     auto *terminator = lastBodyBlock->getTerminator();
236     rewriter.setInsertionPointToEnd(lastBodyBlock);
237     auto step = whileOp.getStep();
238     mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
239     assert(stepped && "must be a Value");
240 
241     llvm::SmallVector<mlir::Value> loopCarried;
242     loopCarried.push_back(stepped);
243     auto begin = whileOp.getFinalValue()
244                      ? std::next(terminator->operand_begin())
245                      : terminator->operand_begin();
246     loopCarried.append(begin, terminator->operand_end());
247     rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
248     rewriter.eraseOp(terminator);
249 
250     // Compute loop bounds before branching to the condition.
251     rewriter.setInsertionPointToEnd(initBlock);
252     auto lowerBound = whileOp.getLowerBound();
253     auto upperBound = whileOp.getUpperBound();
254     assert(lowerBound && upperBound && "must be a Value");
255 
256     // The initial values of loop-carried values is obtained from the operands
257     // of the loop operation.
258     llvm::SmallVector<mlir::Value> destOperands;
259     destOperands.push_back(lowerBound);
260     auto iterOperands = whileOp.getIterOperands();
261     destOperands.append(iterOperands.begin(), iterOperands.end());
262     rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);
263 
264     // With the body block done, we can fill in the condition block.
265     rewriter.setInsertionPointToEnd(conditionBlock);
266     // The comparison depends on the sign of the step value. We fully expect
267     // this expression to be folded by the optimizer or LLVM. This expression
268     // is written this way so that `step == 0` always returns `false`.
269     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
270     auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
271         loc, arith::CmpIPredicate::slt, zero, step);
272     auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
273         loc, arith::CmpIPredicate::sle, iv, upperBound);
274     auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
275         loc, arith::CmpIPredicate::slt, step, zero);
276     auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
277         loc, arith::CmpIPredicate::sle, upperBound, iv);
278     auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
279     auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
280     auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
281     // Remember to AND in the early-exit bool.
282     auto comparison =
283         rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
284     rewriter.create<mlir::cf::CondBranchOp>(
285         loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
286         endBlock, llvm::ArrayRef<mlir::Value>());
287     // The result of the loop operation is the values of the condition block
288     // arguments except the induction variable on the last iteration.
289     auto args = whileOp.getFinalValue()
290                     ? conditionBlock->getArguments()
291                     : conditionBlock->getArguments().drop_front();
292     rewriter.replaceOp(whileOp, args);
293     return success();
294   }
295 };
296 
297 /// Convert FIR structured control flow ops to CFG ops.
298 class CfgConversion : public CFGConversionBase<CfgConversion> {
299 public:
300   void runOnOperation() override {
301     auto *context = &getContext();
302     mlir::RewritePatternSet patterns(context);
303     patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
304         context, forceLoopToExecuteOnce);
305     mlir::ConversionTarget target(*context);
306     target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
307                            FIROpsDialect, mlir::func::FuncDialect>();
308 
309     // apply the patterns
310     target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
311     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
312     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
313                                                   std::move(patterns)))) {
314       mlir::emitError(mlir::UnknownLoc::get(context),
315                       "error in converting to CFG\n");
316       signalPassFailure();
317     }
318   }
319 };
320 } // namespace
321 
322 /// Convert FIR's structured control flow ops to CFG ops.  This
323 /// conversion enables the `createLowerToCFGPass` to transform these to CFG
324 /// form.
325 std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() {
326   return std::make_unique<CfgConversion>();
327 }
328