1 //===-- RewriteLoop.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Transforms/Passes.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/Support/CommandLine.h"
18 
19 using namespace fir;
20 
21 namespace {
22 
23 // Conversion of fir control ops to more primitive control-flow.
24 //
25 // FIR loops that cannot be converted to the affine dialect will remain as
26 // `fir.do_loop` operations.  These can be converted to control-flow operations.
27 
28 /// Convert `fir.do_loop` to CFG
29 class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
30 public:
31   using OpRewritePattern::OpRewritePattern;
32 
33   CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
34       : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
35         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
36 
37   mlir::LogicalResult
38   matchAndRewrite(DoLoopOp loop,
39                   mlir::PatternRewriter &rewriter) const override {
40     auto loc = loop.getLoc();
41 
42     // Create the start and end blocks that will wrap the DoLoopOp with an
43     // initalizer and an end point
44     auto *initBlock = rewriter.getInsertionBlock();
45     auto initPos = rewriter.getInsertionPoint();
46     auto *endBlock = rewriter.splitBlock(initBlock, initPos);
47 
48     // Split the first DoLoopOp block in two parts. The part before will be the
49     // conditional block since it already has the induction variable and
50     // loop-carried values as arguments.
51     auto *conditionalBlock = &loop.region().front();
52     conditionalBlock->addArgument(rewriter.getIndexType());
53     auto *firstBlock =
54         rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
55     auto *lastBlock = &loop.region().back();
56 
57     // Move the blocks from the DoLoopOp between initBlock and endBlock
58     rewriter.inlineRegionBefore(loop.region(), endBlock);
59 
60     // Get loop values from the DoLoopOp
61     auto low = loop.lowerBound();
62     auto high = loop.upperBound();
63     assert(low && high && "must be a Value");
64     auto step = loop.step();
65 
66     // Initalization block
67     rewriter.setInsertionPointToEnd(initBlock);
68     auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
69     auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
70     mlir::Value iters =
71         rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
72 
73     if (forceLoopToExecuteOnce) {
74       auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
75       auto cond = rewriter.create<mlir::arith::CmpIOp>(
76           loc, arith::CmpIPredicate::sle, iters, zero);
77       auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
78       iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters);
79     }
80 
81     llvm::SmallVector<mlir::Value> loopOperands;
82     loopOperands.push_back(low);
83     auto operands = loop.getIterOperands();
84     loopOperands.append(operands.begin(), operands.end());
85     loopOperands.push_back(iters);
86 
87     rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);
88 
89     // Last loop block
90     auto *terminator = lastBlock->getTerminator();
91     rewriter.setInsertionPointToEnd(lastBlock);
92     auto iv = conditionalBlock->getArgument(0);
93     mlir::Value steppedIndex =
94         rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
95     assert(steppedIndex && "must be a Value");
96     auto lastArg = conditionalBlock->getNumArguments() - 1;
97     auto itersLeft = conditionalBlock->getArgument(lastArg);
98     auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
99     mlir::Value itersMinusOne =
100         rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one);
101 
102     llvm::SmallVector<mlir::Value> loopCarried;
103     loopCarried.push_back(steppedIndex);
104     auto begin = loop.finalValue() ? std::next(terminator->operand_begin())
105                                    : terminator->operand_begin();
106     loopCarried.append(begin, terminator->operand_end());
107     loopCarried.push_back(itersMinusOne);
108     rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
109     rewriter.eraseOp(terminator);
110 
111     // Conditional block
112     rewriter.setInsertionPointToEnd(conditionalBlock);
113     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
114     auto comparison = rewriter.create<mlir::arith::CmpIOp>(
115         loc, arith::CmpIPredicate::sgt, itersLeft, zero);
116 
117     rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
118                                         llvm::ArrayRef<mlir::Value>(), endBlock,
119                                         llvm::ArrayRef<mlir::Value>());
120 
121     // The result of the loop operation is the values of the condition block
122     // arguments except the induction variable on the last iteration.
123     auto args = loop.finalValue()
124                     ? conditionalBlock->getArguments()
125                     : conditionalBlock->getArguments().drop_front();
126     rewriter.replaceOp(loop, args.drop_back());
127     return success();
128   }
129 
130 private:
131   bool forceLoopToExecuteOnce;
132 };
133 
134 /// Convert `fir.if` to control-flow
135 class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
136 public:
137   using OpRewritePattern::OpRewritePattern;
138 
139   CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
140       : mlir::OpRewritePattern<fir::IfOp>(ctx),
141         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
142 
143   mlir::LogicalResult
144   matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
145     auto loc = ifOp.getLoc();
146 
147     // Split the block containing the 'fir.if' into two parts.  The part before
148     // will contain the condition, the part after will be the continuation
149     // point.
150     auto *condBlock = rewriter.getInsertionBlock();
151     auto opPosition = rewriter.getInsertionPoint();
152     auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
153     mlir::Block *continueBlock;
154     if (ifOp.getNumResults() == 0) {
155       continueBlock = remainingOpsBlock;
156     } else {
157       continueBlock =
158           rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
159       rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
160     }
161 
162     // Move blocks from the "then" region to the region containing 'fir.if',
163     // place it before the continuation block, and branch to it.
164     auto &ifOpRegion = ifOp.thenRegion();
165     auto *ifOpBlock = &ifOpRegion.front();
166     auto *ifOpTerminator = ifOpRegion.back().getTerminator();
167     auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
168     rewriter.setInsertionPointToEnd(&ifOpRegion.back());
169     rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
170     rewriter.eraseOp(ifOpTerminator);
171     rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
172 
173     // Move blocks from the "else" region (if present) to the region containing
174     // 'fir.if', place it before the continuation block and branch to it.  It
175     // will be placed after the "then" regions.
176     auto *otherwiseBlock = continueBlock;
177     auto &otherwiseRegion = ifOp.elseRegion();
178     if (!otherwiseRegion.empty()) {
179       otherwiseBlock = &otherwiseRegion.front();
180       auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
181       auto otherwiseTermOperands = otherwiseTerm->getOperands();
182       rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
183       rewriter.create<mlir::BranchOp>(loc, continueBlock,
184                                       otherwiseTermOperands);
185       rewriter.eraseOp(otherwiseTerm);
186       rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
187     }
188 
189     rewriter.setInsertionPointToEnd(condBlock);
190     rewriter.create<mlir::CondBranchOp>(
191         loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
192         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
193     rewriter.replaceOp(ifOp, continueBlock->getArguments());
194     return success();
195   }
196 
197 private:
198   bool forceLoopToExecuteOnce;
199 };
200 
201 /// Convert `fir.iter_while` to control-flow.
202 class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
203 public:
204   using OpRewritePattern::OpRewritePattern;
205 
206   CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
207       : mlir::OpRewritePattern<fir::IterWhileOp>(ctx),
208         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
209 
210   mlir::LogicalResult
211   matchAndRewrite(fir::IterWhileOp whileOp,
212                   mlir::PatternRewriter &rewriter) const override {
213     auto loc = whileOp.getLoc();
214 
215     // Start by splitting the block containing the 'fir.do_loop' into two parts.
216     // The part before will get the init code, the part after will be the end
217     // point.
218     auto *initBlock = rewriter.getInsertionBlock();
219     auto initPosition = rewriter.getInsertionPoint();
220     auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
221 
222     // Use the first block of the loop body as the condition block since it is
223     // the block that has the induction variable and loop-carried values as
224     // arguments. Split out all operations from the first block into a new
225     // block. Move all body blocks from the loop body region to the region
226     // containing the loop.
227     auto *conditionBlock = &whileOp.region().front();
228     auto *firstBodyBlock =
229         rewriter.splitBlock(conditionBlock, conditionBlock->begin());
230     auto *lastBodyBlock = &whileOp.region().back();
231     rewriter.inlineRegionBefore(whileOp.region(), endBlock);
232     auto iv = conditionBlock->getArgument(0);
233     auto iterateVar = conditionBlock->getArgument(1);
234 
235     // Append the induction variable stepping logic to the last body block and
236     // branch back to the condition block. Loop-carried values are taken from
237     // operands of the loop terminator.
238     auto *terminator = lastBodyBlock->getTerminator();
239     rewriter.setInsertionPointToEnd(lastBodyBlock);
240     auto step = whileOp.step();
241     mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
242     assert(stepped && "must be a Value");
243 
244     llvm::SmallVector<mlir::Value> loopCarried;
245     loopCarried.push_back(stepped);
246     auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
247                                       : terminator->operand_begin();
248     loopCarried.append(begin, terminator->operand_end());
249     rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
250     rewriter.eraseOp(terminator);
251 
252     // Compute loop bounds before branching to the condition.
253     rewriter.setInsertionPointToEnd(initBlock);
254     auto lowerBound = whileOp.lowerBound();
255     auto upperBound = whileOp.upperBound();
256     assert(lowerBound && upperBound && "must be a Value");
257 
258     // The initial values of loop-carried values is obtained from the operands
259     // of the loop operation.
260     llvm::SmallVector<mlir::Value> destOperands;
261     destOperands.push_back(lowerBound);
262     auto iterOperands = whileOp.getIterOperands();
263     destOperands.append(iterOperands.begin(), iterOperands.end());
264     rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);
265 
266     // With the body block done, we can fill in the condition block.
267     rewriter.setInsertionPointToEnd(conditionBlock);
268     // The comparison depends on the sign of the step value. We fully expect
269     // this expression to be folded by the optimizer or LLVM. This expression
270     // is written this way so that `step == 0` always returns `false`.
271     auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
272     auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
273         loc, arith::CmpIPredicate::slt, zero, step);
274     auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
275         loc, arith::CmpIPredicate::sle, iv, upperBound);
276     auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
277         loc, arith::CmpIPredicate::slt, step, zero);
278     auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
279         loc, arith::CmpIPredicate::sle, upperBound, iv);
280     auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
281     auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
282     auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
283     // Remember to AND in the early-exit bool.
284     auto comparison =
285         rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
286     rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
287                                         llvm::ArrayRef<mlir::Value>(), endBlock,
288                                         llvm::ArrayRef<mlir::Value>());
289     // The result of the loop operation is the values of the condition block
290     // arguments except the induction variable on the last iteration.
291     auto args = whileOp.finalValue()
292                     ? conditionBlock->getArguments()
293                     : conditionBlock->getArguments().drop_front();
294     rewriter.replaceOp(whileOp, args);
295     return success();
296   }
297 
298 private:
299   bool forceLoopToExecuteOnce;
300 };
301 
302 /// Convert FIR structured control flow ops to CFG ops.
303 class CfgConversion : public CFGConversionBase<CfgConversion> {
304 public:
305   void runOnFunction() override {
306     auto *context = &getContext();
307     mlir::OwningRewritePatternList patterns(context);
308     patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
309         context, forceLoopToExecuteOnce);
310     mlir::ConversionTarget target(*context);
311     target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
312                            mlir::StandardOpsDialect>();
313 
314     // apply the patterns
315     target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
316     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
317     if (mlir::failed(mlir::applyPartialConversion(getFunction(), target,
318                                                   std::move(patterns)))) {
319       mlir::emitError(mlir::UnknownLoc::get(context),
320                       "error in converting to CFG\n");
321       signalPassFailure();
322     }
323   }
324 };
325 } // namespace
326 
327 /// Convert FIR's structured control flow ops to CFG ops.  This
328 /// conversion enables the `createLowerToCFGPass` to transform these to CFG
329 /// form.
330 std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() {
331   return std::make_unique<CfgConversion>();
332 }
333