1 //===-- RewriteLoop.cpp ---------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Dialect/FIRDialect.h" 11 #include "flang/Optimizer/Dialect/FIROps.h" 12 #include "flang/Optimizer/Transforms/Passes.h" 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "llvm/Support/CommandLine.h" 19 20 using namespace fir; 21 22 namespace { 23 24 // Conversion of fir control ops to more primitive control-flow. 25 // 26 // FIR loops that cannot be converted to the affine dialect will remain as 27 // `fir.do_loop` operations. These can be converted to control-flow operations. 28 29 /// Convert `fir.do_loop` to CFG 30 class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> { 31 public: 32 using OpRewritePattern::OpRewritePattern; 33 34 CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) 35 : mlir::OpRewritePattern<fir::DoLoopOp>(ctx), 36 forceLoopToExecuteOnce(forceLoopToExecuteOnce) {} 37 38 mlir::LogicalResult 39 matchAndRewrite(DoLoopOp loop, 40 mlir::PatternRewriter &rewriter) const override { 41 auto loc = loop.getLoc(); 42 43 // Create the start and end blocks that will wrap the DoLoopOp with an 44 // initalizer and an end point 45 auto *initBlock = rewriter.getInsertionBlock(); 46 auto initPos = rewriter.getInsertionPoint(); 47 auto *endBlock = rewriter.splitBlock(initBlock, initPos); 48 49 // Split the first DoLoopOp block in two parts. The part before will be the 50 // conditional block since it already has the induction variable and 51 // loop-carried values as arguments. 52 auto *conditionalBlock = &loop.getRegion().front(); 53 conditionalBlock->addArgument(rewriter.getIndexType(), loc); 54 auto *firstBlock = 55 rewriter.splitBlock(conditionalBlock, conditionalBlock->begin()); 56 auto *lastBlock = &loop.getRegion().back(); 57 58 // Move the blocks from the DoLoopOp between initBlock and endBlock 59 rewriter.inlineRegionBefore(loop.getRegion(), endBlock); 60 61 // Get loop values from the DoLoopOp 62 auto low = loop.getLowerBound(); 63 auto high = loop.getUpperBound(); 64 assert(low && high && "must be a Value"); 65 auto step = loop.getStep(); 66 67 // Initalization block 68 rewriter.setInsertionPointToEnd(initBlock); 69 auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low); 70 auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step); 71 mlir::Value iters = 72 rewriter.create<mlir::arith::DivSIOp>(loc, distance, step); 73 74 if (forceLoopToExecuteOnce) { 75 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 76 auto cond = rewriter.create<mlir::arith::CmpIOp>( 77 loc, arith::CmpIPredicate::sle, iters, zero); 78 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 79 iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters); 80 } 81 82 llvm::SmallVector<mlir::Value> loopOperands; 83 loopOperands.push_back(low); 84 auto operands = loop.getIterOperands(); 85 loopOperands.append(operands.begin(), operands.end()); 86 loopOperands.push_back(iters); 87 88 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands); 89 90 // Last loop block 91 auto *terminator = lastBlock->getTerminator(); 92 rewriter.setInsertionPointToEnd(lastBlock); 93 auto iv = conditionalBlock->getArgument(0); 94 mlir::Value steppedIndex = 95 rewriter.create<mlir::arith::AddIOp>(loc, iv, step); 96 assert(steppedIndex && "must be a Value"); 97 auto lastArg = conditionalBlock->getNumArguments() - 1; 98 auto itersLeft = conditionalBlock->getArgument(lastArg); 99 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 100 mlir::Value itersMinusOne = 101 rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one); 102 103 llvm::SmallVector<mlir::Value> loopCarried; 104 loopCarried.push_back(steppedIndex); 105 auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin()) 106 : terminator->operand_begin(); 107 loopCarried.append(begin, terminator->operand_end()); 108 loopCarried.push_back(itersMinusOne); 109 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried); 110 rewriter.eraseOp(terminator); 111 112 // Conditional block 113 rewriter.setInsertionPointToEnd(conditionalBlock); 114 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 115 auto comparison = rewriter.create<mlir::arith::CmpIOp>( 116 loc, arith::CmpIPredicate::sgt, itersLeft, zero); 117 118 rewriter.create<mlir::cf::CondBranchOp>( 119 loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock, 120 llvm::ArrayRef<mlir::Value>()); 121 122 // The result of the loop operation is the values of the condition block 123 // arguments except the induction variable on the last iteration. 124 auto args = loop.getFinalValue() 125 ? conditionalBlock->getArguments() 126 : conditionalBlock->getArguments().drop_front(); 127 rewriter.replaceOp(loop, args.drop_back()); 128 return success(); 129 } 130 131 private: 132 bool forceLoopToExecuteOnce; 133 }; 134 135 /// Convert `fir.if` to control-flow 136 class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> { 137 public: 138 using OpRewritePattern::OpRewritePattern; 139 140 CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) 141 : mlir::OpRewritePattern<fir::IfOp>(ctx) {} 142 143 mlir::LogicalResult 144 matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override { 145 auto loc = ifOp.getLoc(); 146 147 // Split the block containing the 'fir.if' into two parts. The part before 148 // will contain the condition, the part after will be the continuation 149 // point. 150 auto *condBlock = rewriter.getInsertionBlock(); 151 auto opPosition = rewriter.getInsertionPoint(); 152 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 153 mlir::Block *continueBlock; 154 if (ifOp.getNumResults() == 0) { 155 continueBlock = remainingOpsBlock; 156 } else { 157 continueBlock = 158 rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes()); 159 rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock); 160 } 161 162 // Move blocks from the "then" region to the region containing 'fir.if', 163 // place it before the continuation block, and branch to it. 164 auto &ifOpRegion = ifOp.getThenRegion(); 165 auto *ifOpBlock = &ifOpRegion.front(); 166 auto *ifOpTerminator = ifOpRegion.back().getTerminator(); 167 auto ifOpTerminatorOperands = ifOpTerminator->getOperands(); 168 rewriter.setInsertionPointToEnd(&ifOpRegion.back()); 169 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, 170 ifOpTerminatorOperands); 171 rewriter.eraseOp(ifOpTerminator); 172 rewriter.inlineRegionBefore(ifOpRegion, continueBlock); 173 174 // Move blocks from the "else" region (if present) to the region containing 175 // 'fir.if', place it before the continuation block and branch to it. It 176 // will be placed after the "then" regions. 177 auto *otherwiseBlock = continueBlock; 178 auto &otherwiseRegion = ifOp.getElseRegion(); 179 if (!otherwiseRegion.empty()) { 180 otherwiseBlock = &otherwiseRegion.front(); 181 auto *otherwiseTerm = otherwiseRegion.back().getTerminator(); 182 auto otherwiseTermOperands = otherwiseTerm->getOperands(); 183 rewriter.setInsertionPointToEnd(&otherwiseRegion.back()); 184 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, 185 otherwiseTermOperands); 186 rewriter.eraseOp(otherwiseTerm); 187 rewriter.inlineRegionBefore(otherwiseRegion, continueBlock); 188 } 189 190 rewriter.setInsertionPointToEnd(condBlock); 191 rewriter.create<mlir::cf::CondBranchOp>( 192 loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(), 193 otherwiseBlock, llvm::ArrayRef<mlir::Value>()); 194 rewriter.replaceOp(ifOp, continueBlock->getArguments()); 195 return success(); 196 } 197 }; 198 199 /// Convert `fir.iter_while` to control-flow. 200 class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> { 201 public: 202 using OpRewritePattern::OpRewritePattern; 203 204 CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) 205 : mlir::OpRewritePattern<fir::IterWhileOp>(ctx) {} 206 207 mlir::LogicalResult 208 matchAndRewrite(fir::IterWhileOp whileOp, 209 mlir::PatternRewriter &rewriter) const override { 210 auto loc = whileOp.getLoc(); 211 212 // Start by splitting the block containing the 'fir.do_loop' into two parts. 213 // The part before will get the init code, the part after will be the end 214 // point. 215 auto *initBlock = rewriter.getInsertionBlock(); 216 auto initPosition = rewriter.getInsertionPoint(); 217 auto *endBlock = rewriter.splitBlock(initBlock, initPosition); 218 219 // Use the first block of the loop body as the condition block since it is 220 // the block that has the induction variable and loop-carried values as 221 // arguments. Split out all operations from the first block into a new 222 // block. Move all body blocks from the loop body region to the region 223 // containing the loop. 224 auto *conditionBlock = &whileOp.getRegion().front(); 225 auto *firstBodyBlock = 226 rewriter.splitBlock(conditionBlock, conditionBlock->begin()); 227 auto *lastBodyBlock = &whileOp.getRegion().back(); 228 rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock); 229 auto iv = conditionBlock->getArgument(0); 230 auto iterateVar = conditionBlock->getArgument(1); 231 232 // Append the induction variable stepping logic to the last body block and 233 // branch back to the condition block. Loop-carried values are taken from 234 // operands of the loop terminator. 235 auto *terminator = lastBodyBlock->getTerminator(); 236 rewriter.setInsertionPointToEnd(lastBodyBlock); 237 auto step = whileOp.getStep(); 238 mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step); 239 assert(stepped && "must be a Value"); 240 241 llvm::SmallVector<mlir::Value> loopCarried; 242 loopCarried.push_back(stepped); 243 auto begin = whileOp.getFinalValue() 244 ? std::next(terminator->operand_begin()) 245 : terminator->operand_begin(); 246 loopCarried.append(begin, terminator->operand_end()); 247 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried); 248 rewriter.eraseOp(terminator); 249 250 // Compute loop bounds before branching to the condition. 251 rewriter.setInsertionPointToEnd(initBlock); 252 auto lowerBound = whileOp.getLowerBound(); 253 auto upperBound = whileOp.getUpperBound(); 254 assert(lowerBound && upperBound && "must be a Value"); 255 256 // The initial values of loop-carried values is obtained from the operands 257 // of the loop operation. 258 llvm::SmallVector<mlir::Value> destOperands; 259 destOperands.push_back(lowerBound); 260 auto iterOperands = whileOp.getIterOperands(); 261 destOperands.append(iterOperands.begin(), iterOperands.end()); 262 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands); 263 264 // With the body block done, we can fill in the condition block. 265 rewriter.setInsertionPointToEnd(conditionBlock); 266 // The comparison depends on the sign of the step value. We fully expect 267 // this expression to be folded by the optimizer or LLVM. This expression 268 // is written this way so that `step == 0` always returns `false`. 269 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 270 auto compl0 = rewriter.create<mlir::arith::CmpIOp>( 271 loc, arith::CmpIPredicate::slt, zero, step); 272 auto compl1 = rewriter.create<mlir::arith::CmpIOp>( 273 loc, arith::CmpIPredicate::sle, iv, upperBound); 274 auto compl2 = rewriter.create<mlir::arith::CmpIOp>( 275 loc, arith::CmpIPredicate::slt, step, zero); 276 auto compl3 = rewriter.create<mlir::arith::CmpIOp>( 277 loc, arith::CmpIPredicate::sle, upperBound, iv); 278 auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1); 279 auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3); 280 auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1); 281 // Remember to AND in the early-exit bool. 282 auto comparison = 283 rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2); 284 rewriter.create<mlir::cf::CondBranchOp>( 285 loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(), 286 endBlock, llvm::ArrayRef<mlir::Value>()); 287 // The result of the loop operation is the values of the condition block 288 // arguments except the induction variable on the last iteration. 289 auto args = whileOp.getFinalValue() 290 ? conditionBlock->getArguments() 291 : conditionBlock->getArguments().drop_front(); 292 rewriter.replaceOp(whileOp, args); 293 return success(); 294 } 295 }; 296 297 /// Convert FIR structured control flow ops to CFG ops. 298 class CfgConversion : public CFGConversionBase<CfgConversion> { 299 public: 300 void runOnOperation() override { 301 auto *context = &getContext(); 302 mlir::RewritePatternSet patterns(context); 303 patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( 304 context, forceLoopToExecuteOnce); 305 mlir::ConversionTarget target(*context); 306 target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect, 307 FIROpsDialect, mlir::func::FuncDialect>(); 308 309 // apply the patterns 310 target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); 311 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 312 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, 313 std::move(patterns)))) { 314 mlir::emitError(mlir::UnknownLoc::get(context), 315 "error in converting to CFG\n"); 316 signalPassFailure(); 317 } 318 } 319 }; 320 } // namespace 321 322 /// Convert FIR's structured control flow ops to CFG ops. This 323 /// conversion enables the `createLowerToCFGPass` to transform these to CFG 324 /// form. 325 std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() { 326 return std::make_unique<CfgConversion>(); 327 } 328