1 //===-- RewriteLoop.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Dialect/FIROps.h"
12 #include "flang/Optimizer/Transforms/Passes.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/Support/CommandLine.h"
18 
19 using namespace fir;
20 
21 namespace {
22 
23 // Conversion of fir control ops to more primitive control-flow.
24 //
25 // FIR loops that cannot be converted to the affine dialect will remain as
26 // `fir.do_loop` operations.  These can be converted to control-flow operations.
27 
28 /// Convert `fir.do_loop` to CFG
29 class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
30 public:
31   using OpRewritePattern::OpRewritePattern;
32 
33   CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
34       : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
35         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
36 
37   mlir::LogicalResult
38   matchAndRewrite(DoLoopOp loop,
39                   mlir::PatternRewriter &rewriter) const override {
40     auto loc = loop.getLoc();
41 
42     // Create the start and end blocks that will wrap the DoLoopOp with an
43     // initalizer and an end point
44     auto *initBlock = rewriter.getInsertionBlock();
45     auto initPos = rewriter.getInsertionPoint();
46     auto *endBlock = rewriter.splitBlock(initBlock, initPos);
47 
48     // Split the first DoLoopOp block in two parts. The part before will be the
49     // conditional block since it already has the induction variable and
50     // loop-carried values as arguments.
51     auto *conditionalBlock = &loop.region().front();
52     conditionalBlock->addArgument(rewriter.getIndexType());
53     auto *firstBlock =
54         rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
55     auto *lastBlock = &loop.region().back();
56 
57     // Move the blocks from the DoLoopOp between initBlock and endBlock
58     rewriter.inlineRegionBefore(loop.region(), endBlock);
59 
60     // Get loop values from the DoLoopOp
61     auto low = loop.lowerBound();
62     auto high = loop.upperBound();
63     assert(low && high && "must be a Value");
64     auto step = loop.step();
65 
66     // Initalization block
67     rewriter.setInsertionPointToEnd(initBlock);
68     auto diff = rewriter.create<mlir::SubIOp>(loc, high, low);
69     auto distance = rewriter.create<mlir::AddIOp>(loc, diff, step);
70     mlir::Value iters =
71         rewriter.create<mlir::SignedDivIOp>(loc, distance, step);
72 
73     if (forceLoopToExecuteOnce) {
74       auto zero = rewriter.create<mlir::ConstantIndexOp>(loc, 0);
75       auto cond =
76           rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sle, iters, zero);
77       auto one = rewriter.create<mlir::ConstantIndexOp>(loc, 1);
78       iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters);
79     }
80 
81     llvm::SmallVector<mlir::Value> loopOperands;
82     loopOperands.push_back(low);
83     auto operands = loop.getIterOperands();
84     loopOperands.append(operands.begin(), operands.end());
85     loopOperands.push_back(iters);
86 
87     rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);
88 
89     // Last loop block
90     auto *terminator = lastBlock->getTerminator();
91     rewriter.setInsertionPointToEnd(lastBlock);
92     auto iv = conditionalBlock->getArgument(0);
93     mlir::Value steppedIndex = rewriter.create<mlir::AddIOp>(loc, iv, step);
94     assert(steppedIndex && "must be a Value");
95     auto lastArg = conditionalBlock->getNumArguments() - 1;
96     auto itersLeft = conditionalBlock->getArgument(lastArg);
97     auto one = rewriter.create<mlir::ConstantIndexOp>(loc, 1);
98     mlir::Value itersMinusOne =
99         rewriter.create<mlir::SubIOp>(loc, itersLeft, one);
100 
101     llvm::SmallVector<mlir::Value> loopCarried;
102     loopCarried.push_back(steppedIndex);
103     auto begin = loop.finalValue() ? std::next(terminator->operand_begin())
104                                    : terminator->operand_begin();
105     loopCarried.append(begin, terminator->operand_end());
106     loopCarried.push_back(itersMinusOne);
107     rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
108     rewriter.eraseOp(terminator);
109 
110     // Conditional block
111     rewriter.setInsertionPointToEnd(conditionalBlock);
112     auto zero = rewriter.create<mlir::ConstantIndexOp>(loc, 0);
113     auto comparison =
114         rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, itersLeft, zero);
115 
116     rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
117                                         llvm::ArrayRef<mlir::Value>(), endBlock,
118                                         llvm::ArrayRef<mlir::Value>());
119 
120     // The result of the loop operation is the values of the condition block
121     // arguments except the induction variable on the last iteration.
122     auto args = loop.finalValue()
123                     ? conditionalBlock->getArguments()
124                     : conditionalBlock->getArguments().drop_front();
125     rewriter.replaceOp(loop, args.drop_back());
126     return success();
127   }
128 
129 private:
130   bool forceLoopToExecuteOnce;
131 };
132 
133 /// Convert `fir.if` to control-flow
134 class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
135 public:
136   using OpRewritePattern::OpRewritePattern;
137 
138   CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
139       : mlir::OpRewritePattern<fir::IfOp>(ctx),
140         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
141 
142   mlir::LogicalResult
143   matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
144     auto loc = ifOp.getLoc();
145 
146     // Split the block containing the 'fir.if' into two parts.  The part before
147     // will contain the condition, the part after will be the continuation
148     // point.
149     auto *condBlock = rewriter.getInsertionBlock();
150     auto opPosition = rewriter.getInsertionPoint();
151     auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
152     mlir::Block *continueBlock;
153     if (ifOp.getNumResults() == 0) {
154       continueBlock = remainingOpsBlock;
155     } else {
156       continueBlock =
157           rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
158       rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
159     }
160 
161     // Move blocks from the "then" region to the region containing 'fir.if',
162     // place it before the continuation block, and branch to it.
163     auto &ifOpRegion = ifOp.thenRegion();
164     auto *ifOpBlock = &ifOpRegion.front();
165     auto *ifOpTerminator = ifOpRegion.back().getTerminator();
166     auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
167     rewriter.setInsertionPointToEnd(&ifOpRegion.back());
168     rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
169     rewriter.eraseOp(ifOpTerminator);
170     rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
171 
172     // Move blocks from the "else" region (if present) to the region containing
173     // 'fir.if', place it before the continuation block and branch to it.  It
174     // will be placed after the "then" regions.
175     auto *otherwiseBlock = continueBlock;
176     auto &otherwiseRegion = ifOp.elseRegion();
177     if (!otherwiseRegion.empty()) {
178       otherwiseBlock = &otherwiseRegion.front();
179       auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
180       auto otherwiseTermOperands = otherwiseTerm->getOperands();
181       rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
182       rewriter.create<mlir::BranchOp>(loc, continueBlock,
183                                       otherwiseTermOperands);
184       rewriter.eraseOp(otherwiseTerm);
185       rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
186     }
187 
188     rewriter.setInsertionPointToEnd(condBlock);
189     rewriter.create<mlir::CondBranchOp>(
190         loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
191         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
192     rewriter.replaceOp(ifOp, continueBlock->getArguments());
193     return success();
194   }
195 
196 private:
197   bool forceLoopToExecuteOnce;
198 };
199 
200 /// Convert `fir.iter_while` to control-flow.
201 class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
202 public:
203   using OpRewritePattern::OpRewritePattern;
204 
205   CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
206       : mlir::OpRewritePattern<fir::IterWhileOp>(ctx),
207         forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}
208 
209   mlir::LogicalResult
210   matchAndRewrite(fir::IterWhileOp whileOp,
211                   mlir::PatternRewriter &rewriter) const override {
212     auto loc = whileOp.getLoc();
213 
214     // Start by splitting the block containing the 'fir.do_loop' into two parts.
215     // The part before will get the init code, the part after will be the end
216     // point.
217     auto *initBlock = rewriter.getInsertionBlock();
218     auto initPosition = rewriter.getInsertionPoint();
219     auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
220 
221     // Use the first block of the loop body as the condition block since it is
222     // the block that has the induction variable and loop-carried values as
223     // arguments. Split out all operations from the first block into a new
224     // block. Move all body blocks from the loop body region to the region
225     // containing the loop.
226     auto *conditionBlock = &whileOp.region().front();
227     auto *firstBodyBlock =
228         rewriter.splitBlock(conditionBlock, conditionBlock->begin());
229     auto *lastBodyBlock = &whileOp.region().back();
230     rewriter.inlineRegionBefore(whileOp.region(), endBlock);
231     auto iv = conditionBlock->getArgument(0);
232     auto iterateVar = conditionBlock->getArgument(1);
233 
234     // Append the induction variable stepping logic to the last body block and
235     // branch back to the condition block. Loop-carried values are taken from
236     // operands of the loop terminator.
237     auto *terminator = lastBodyBlock->getTerminator();
238     rewriter.setInsertionPointToEnd(lastBodyBlock);
239     auto step = whileOp.step();
240     mlir::Value stepped = rewriter.create<mlir::AddIOp>(loc, iv, step);
241     assert(stepped && "must be a Value");
242 
243     llvm::SmallVector<mlir::Value> loopCarried;
244     loopCarried.push_back(stepped);
245     auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
246                                       : terminator->operand_begin();
247     loopCarried.append(begin, terminator->operand_end());
248     rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
249     rewriter.eraseOp(terminator);
250 
251     // Compute loop bounds before branching to the condition.
252     rewriter.setInsertionPointToEnd(initBlock);
253     auto lowerBound = whileOp.lowerBound();
254     auto upperBound = whileOp.upperBound();
255     assert(lowerBound && upperBound && "must be a Value");
256 
257     // The initial values of loop-carried values is obtained from the operands
258     // of the loop operation.
259     llvm::SmallVector<mlir::Value> destOperands;
260     destOperands.push_back(lowerBound);
261     auto iterOperands = whileOp.getIterOperands();
262     destOperands.append(iterOperands.begin(), iterOperands.end());
263     rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);
264 
265     // With the body block done, we can fill in the condition block.
266     rewriter.setInsertionPointToEnd(conditionBlock);
267     // The comparison depends on the sign of the step value. We fully expect
268     // this expression to be folded by the optimizer or LLVM. This expression
269     // is written this way so that `step == 0` always returns `false`.
270     auto zero = rewriter.create<mlir::ConstantIndexOp>(loc, 0);
271     auto compl0 =
272         rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt, zero, step);
273     auto compl1 =
274         rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sle, iv, upperBound);
275     auto compl2 =
276         rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt, step, zero);
277     auto compl3 =
278         rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sle, upperBound, iv);
279     auto cmp0 = rewriter.create<mlir::AndOp>(loc, compl0, compl1);
280     auto cmp1 = rewriter.create<mlir::AndOp>(loc, compl2, compl3);
281     auto cmp2 = rewriter.create<mlir::OrOp>(loc, cmp0, cmp1);
282     // Remember to AND in the early-exit bool.
283     auto comparison = rewriter.create<mlir::AndOp>(loc, iterateVar, cmp2);
284     rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
285                                         llvm::ArrayRef<mlir::Value>(), endBlock,
286                                         llvm::ArrayRef<mlir::Value>());
287     // The result of the loop operation is the values of the condition block
288     // arguments except the induction variable on the last iteration.
289     auto args = whileOp.finalValue()
290                     ? conditionBlock->getArguments()
291                     : conditionBlock->getArguments().drop_front();
292     rewriter.replaceOp(whileOp, args);
293     return success();
294   }
295 
296 private:
297   bool forceLoopToExecuteOnce;
298 };
299 
300 /// Convert FIR structured control flow ops to CFG ops.
301 class CfgConversion : public CFGConversionBase<CfgConversion> {
302 public:
303   void runOnFunction() override {
304     auto *context = &getContext();
305     mlir::OwningRewritePatternList patterns(context);
306     patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
307         context, forceLoopToExecuteOnce);
308     mlir::ConversionTarget target(*context);
309     target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
310                            mlir::StandardOpsDialect>();
311 
312     // apply the patterns
313     target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
314     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
315     if (mlir::failed(mlir::applyPartialConversion(getFunction(), target,
316                                                   std::move(patterns)))) {
317       mlir::emitError(mlir::UnknownLoc::get(context),
318                       "error in converting to CFG\n");
319       signalPassFailure();
320     }
321   }
322 };
323 } // namespace
324 
325 /// Convert FIR's structured control flow ops to CFG ops.  This
326 /// conversion enables the `createLowerToCFGPass` to transform these to CFG
327 /// form.
328 std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() {
329   return std::make_unique<CfgConversion>();
330 }
331