1 //===-- RewriteLoop.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Transforms/Passes.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/Support/CommandLine.h"
18 
19 using namespace fir;
20 
21 namespace {
22 
23 // Conversion of fir control ops to more primitive control-flow.
24 //
25 // FIR loops that cannot be converted to the affine dialect will remain as
26 // `fir.do_loop` operations.  These can be converted to control-flow operations.
27 
/// Convert `fir.do_loop` to CFG.
///
/// The rewrite produces the following shape:
///   - `initBlock`: the operations preceding the loop, followed by the trip
///     count computation and a branch to the conditional block;
///   - conditional block: the former entry block of the loop body, now holding
///     only its block arguments (induction variable, loop-carried values and
///     an added remaining-iteration count) and a conditional branch to the
///     body or to `endBlock`;
///   - body blocks: the original body, with the last block's terminator
///     replaced by a branch back to the conditional block;
///   - `endBlock`: the operations that followed the loop.
///
/// The trip count is `(upper - lower + step) / step` (signed division), and
/// the body runs while the remaining-iteration count is greater than zero.
class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// \param forceLoopToExecuteOnce when true, a non-positive trip count is
  /// replaced by 1 so that the loop body executes at least once.
  CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
      : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
        forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}

  mlir::LogicalResult
  matchAndRewrite(DoLoopOp loop,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = loop.getLoc();

    // Create the start and end blocks that will wrap the DoLoopOp with an
    // initializer and an end point. The current block is split at the
    // rewriter's insertion point. NOTE(review): this relies on the pattern
    // driver having positioned the insertion point at `loop` — confirm.
    auto *initBlock = rewriter.getInsertionBlock();
    auto initPos = rewriter.getInsertionPoint();
    auto *endBlock = rewriter.splitBlock(initBlock, initPos);

    // Split the first DoLoopOp block in two parts. The part before will be the
    // conditional block since it already has the induction variable and
    // loop-carried values as arguments. An extra index argument is appended
    // to carry the number of iterations still left to run.
    auto *conditionalBlock = &loop.region().front();
    conditionalBlock->addArgument(rewriter.getIndexType());
    auto *firstBlock =
        rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
    // Must be read before the region's blocks are moved out just below.
    auto *lastBlock = &loop.region().back();

    // Move the blocks from the DoLoopOp between initBlock and endBlock
    rewriter.inlineRegionBefore(loop.region(), endBlock);

    // Get loop values from the DoLoopOp
    auto low = loop.lowerBound();
    auto high = loop.upperBound();
    assert(low && high && "must be a Value");
    auto step = loop.step();

    // Initialization block: iters = (high - low + step) / step.
    rewriter.setInsertionPointToEnd(initBlock);
    auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
    auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
    mlir::Value iters =
        rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);

    // Optionally force at least one iteration: iters = (iters <= 0) ? 1 : iters.
    if (forceLoopToExecuteOnce) {
      auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
      auto cond = rewriter.create<mlir::arith::CmpIOp>(
          loc, arith::CmpIPredicate::sle, iters, zero);
      auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
      iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters);
    }

    // Enter the loop. The conditional block's arguments are, in order: the
    // induction variable (starting at `low`), the loop-carried values
    // (starting at the loop's iteration operands) and the iteration count.
    llvm::SmallVector<mlir::Value> loopOperands;
    loopOperands.push_back(low);
    auto operands = loop.getIterOperands();
    loopOperands.append(operands.begin(), operands.end());
    loopOperands.push_back(iters);

    rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);

    // Last loop block: advance the induction variable by `step`, decrement
    // the remaining-iteration count, and branch back to the conditional block
    // in place of the body's terminator.
    auto *terminator = lastBlock->getTerminator();
    rewriter.setInsertionPointToEnd(lastBlock);
    auto iv = conditionalBlock->getArgument(0);
    mlir::Value steppedIndex =
        rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
    assert(steppedIndex && "must be a Value");
    auto lastArg = conditionalBlock->getNumArguments() - 1;
    auto itersLeft = conditionalBlock->getArgument(lastArg);
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value itersMinusOne =
        rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one);

    // The terminator's operands become the next loop-carried values. When the
    // loop has a final value, the terminator's first operand is skipped
    // (NOTE(review): presumably the induction value, which `steppedIndex`
    // supersedes here — confirm).
    llvm::SmallVector<mlir::Value> loopCarried;
    loopCarried.push_back(steppedIndex);
    auto begin = loop.finalValue() ? std::next(terminator->operand_begin())
                                   : terminator->operand_begin();
    loopCarried.append(begin, terminator->operand_end());
    loopCarried.push_back(itersMinusOne);
    rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
    rewriter.eraseOp(terminator);

    // Conditional block: enter the body while iterations remain, otherwise
    // exit to endBlock.
    rewriter.setInsertionPointToEnd(conditionalBlock);
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto comparison = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, itersLeft, zero);

    rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
                                        llvm::ArrayRef<mlir::Value>(), endBlock,
                                        llvm::ArrayRef<mlir::Value>());

    // The results of the loop are the conditional block's arguments as seen
    // on the exiting iteration, without the trailing iteration count and,
    // unless the loop has a final value, without the leading induction
    // variable.
    auto args = loop.finalValue()
                    ? conditionalBlock->getArguments()
                    : conditionalBlock->getArguments().drop_front();
    rewriter.replaceOp(loop, args.drop_back());
    return success();
  }

private:
  /// When true, the trip count is clamped so the body executes at least once.
  bool forceLoopToExecuteOnce;
};
133 
/// Convert `fir.if` to control-flow.
///
/// The block holding the `fir.if` is split at the op. The first half ends in a
/// conditional branch to the "then" entry block, or to the "else" entry block
/// (or straight to the continuation when there is no "else" region). The
/// terminator of each region's last block is replaced by a branch to the
/// continuation block that forwards the terminator's operands. When the op
/// has results, the continuation is a new block whose arguments receive them
/// and which then branches on to the operations that followed the `fir.if`.
class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// `forceLoopToExecuteOnce` is ignored here; it is accepted so that all the
  /// patterns of this pass can be registered with the same arguments.
  CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
      : mlir::OpRewritePattern<fir::IfOp>(ctx) {}

  mlir::LogicalResult
  matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
    auto loc = ifOp.getLoc();

    // Split the block containing the 'fir.if' into two parts.  The part before
    // will contain the condition, the part after will be the continuation
    // point.
    auto *condBlock = rewriter.getInsertionBlock();
    auto opPosition = rewriter.getInsertionPoint();
    auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
    mlir::Block *continueBlock;
    if (ifOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      // The results need block arguments to land in, so insert a dedicated
      // continuation block (typed by the op's results) in front of the
      // remaining operations and fall through to them.
      continueBlock =
          rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
      rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
    }

    // Move blocks from the "then" region to the region containing 'fir.if',
    // place it before the continuation block, and branch to it.  The entry
    // and terminator are captured before the region is inlined.
    auto &ifOpRegion = ifOp.thenRegion();
    auto *ifOpBlock = &ifOpRegion.front();
    auto *ifOpTerminator = ifOpRegion.back().getTerminator();
    auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
    rewriter.setInsertionPointToEnd(&ifOpRegion.back());
    rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
    rewriter.eraseOp(ifOpTerminator);
    rewriter.inlineRegionBefore(ifOpRegion, continueBlock);

    // Move blocks from the "else" region (if present) to the region containing
    // 'fir.if', place it before the continuation block and branch to it.  It
    // will be placed after the "then" regions.
    auto *otherwiseBlock = continueBlock;
    auto &otherwiseRegion = ifOp.elseRegion();
    if (!otherwiseRegion.empty()) {
      otherwiseBlock = &otherwiseRegion.front();
      auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
      auto otherwiseTermOperands = otherwiseTerm->getOperands();
      rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
      rewriter.create<mlir::BranchOp>(loc, continueBlock,
                                      otherwiseTermOperands);
      rewriter.eraseOp(otherwiseTerm);
      rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
    }

    // Terminate the condition block with the two-way branch, then let the
    // continuation block's arguments (none when the op has no results) stand
    // in for the op's results.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<mlir::CondBranchOp>(
        loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
        otherwiseBlock, llvm::ArrayRef<mlir::Value>());
    rewriter.replaceOp(ifOp, continueBlock->getArguments());
    return success();
  }
};
196 
/// Convert `fir::IterWhileOp` (`fir.iterate_while`) to control-flow.
///
/// The loop body's entry block becomes the condition block. Its arguments are
/// the induction variable, the early-exit flag and the loop-carried values.
/// It branches to the body while the flag is true and the induction variable
/// has not moved past the upper bound in the direction of `step`; otherwise it
/// branches to the block following the loop.
class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// `forceLoopToExecuteOnce` is ignored here; it is accepted so that all the
  /// patterns of this pass can be registered with the same arguments.
  CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
      : mlir::OpRewritePattern<fir::IterWhileOp>(ctx) {}

  mlir::LogicalResult
  matchAndRewrite(fir::IterWhileOp whileOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = whileOp.getLoc();

    // Start by splitting the block containing the iterate-while loop into two
    // parts. The part before will get the init code, the part after will be
    // the end point.
    auto *initBlock = rewriter.getInsertionBlock();
    auto initPosition = rewriter.getInsertionPoint();
    auto *endBlock = rewriter.splitBlock(initBlock, initPosition);

    // Use the first block of the loop body as the condition block since it is
    // the block that has the induction variable and loop-carried values as
    // arguments. Split out all operations from the first block into a new
    // block. Move all body blocks from the loop body region to the region
    // containing the loop.
    auto *conditionBlock = &whileOp.region().front();
    auto *firstBodyBlock =
        rewriter.splitBlock(conditionBlock, conditionBlock->begin());
    auto *lastBodyBlock = &whileOp.region().back();
    rewriter.inlineRegionBefore(whileOp.region(), endBlock);
    // Argument 0 is the induction variable, argument 1 the early-exit flag.
    auto iv = conditionBlock->getArgument(0);
    auto iterateVar = conditionBlock->getArgument(1);

    // Append the induction variable stepping logic to the last body block and
    // branch back to the condition block. Loop-carried values are taken from
    // operands of the loop terminator (skipping its first operand when the
    // loop has a final value).
    auto *terminator = lastBodyBlock->getTerminator();
    rewriter.setInsertionPointToEnd(lastBodyBlock);
    auto step = whileOp.step();
    mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
    assert(stepped && "must be a Value");

    llvm::SmallVector<mlir::Value> loopCarried;
    loopCarried.push_back(stepped);
    auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
                                      : terminator->operand_begin();
    loopCarried.append(begin, terminator->operand_end());
    rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
    rewriter.eraseOp(terminator);

    // Compute loop bounds before branching to the condition.
    rewriter.setInsertionPointToEnd(initBlock);
    auto lowerBound = whileOp.lowerBound();
    auto upperBound = whileOp.upperBound();
    assert(lowerBound && upperBound && "must be a Value");

    // The initial values of loop-carried values is obtained from the operands
    // of the loop operation.
    llvm::SmallVector<mlir::Value> destOperands;
    destOperands.push_back(lowerBound);
    auto iterOperands = whileOp.getIterOperands();
    destOperands.append(iterOperands.begin(), iterOperands.end());
    rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);

    // With the body block done, we can fill in the condition block.
    rewriter.setInsertionPointToEnd(conditionBlock);
    // The comparison depends on the sign of the step value. We fully expect
    // this expression to be folded by the optimizer or LLVM. This expression
    // is written this way so that `step == 0` always returns `false`:
    //   iterate && ((0 < step && iv <= ub) || (step < 0 && ub <= iv))
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, zero, step);
    auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, iv, upperBound);
    auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, step, zero);
    auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, upperBound, iv);
    auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
    auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
    auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
    // Remember to AND in the early-exit bool.
    auto comparison =
        rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
    rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
                                        llvm::ArrayRef<mlir::Value>(), endBlock,
                                        llvm::ArrayRef<mlir::Value>());
    // The result of the loop operation is the values of the condition block
    // arguments on the last iteration, without the leading induction variable
    // unless the loop has a final value.
    auto args = whileOp.finalValue()
                    ? conditionBlock->getArguments()
                    : conditionBlock->getArguments().drop_front();
    rewriter.replaceOp(whileOp, args);
    return success();
  }
};
293 
294 /// Convert FIR structured control flow ops to CFG ops.
295 class CfgConversion : public CFGConversionBase<CfgConversion> {
296 public:
297   void runOnFunction() override {
298     auto *context = &getContext();
299     mlir::OwningRewritePatternList patterns(context);
300     patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
301         context, forceLoopToExecuteOnce);
302     mlir::ConversionTarget target(*context);
303     target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
304                            mlir::StandardOpsDialect>();
305 
306     // apply the patterns
307     target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
308     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
309     if (mlir::failed(mlir::applyPartialConversion(getFunction(), target,
310                                                   std::move(patterns)))) {
311       mlir::emitError(mlir::UnknownLoc::get(context),
312                       "error in converting to CFG\n");
313       signalPassFailure();
314     }
315   }
316 };
317 } // namespace
318 
319 /// Convert FIR's structured control flow ops to CFG ops.  This
320 /// conversion enables the `createLowerToCFGPass` to transform these to CFG
321 /// form.
322 std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() {
323   return std::make_unique<CfgConversion>();
324 }
325