1 //===-- RewriteLoop.cpp ---------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Dialect/FIRDialect.h" 11 #include "flang/Optimizer/Dialect/FIROps.h" 12 #include "flang/Optimizer/Transforms/Passes.h" 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "llvm/Support/CommandLine.h" 19 20 using namespace fir; 21 using namespace mlir; 22 23 namespace { 24 25 // Conversion of fir control ops to more primitive control-flow. 26 // 27 // FIR loops that cannot be converted to the affine dialect will remain as 28 // `fir.do_loop` operations. These can be converted to control-flow operations. 29 30 /// Convert `fir.do_loop` to CFG 31 class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> { 32 public: 33 using OpRewritePattern::OpRewritePattern; 34 35 CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) 36 : mlir::OpRewritePattern<fir::DoLoopOp>(ctx), 37 forceLoopToExecuteOnce(forceLoopToExecuteOnce) {} 38 39 mlir::LogicalResult 40 matchAndRewrite(DoLoopOp loop, 41 mlir::PatternRewriter &rewriter) const override { 42 auto loc = loop.getLoc(); 43 44 // Create the start and end blocks that will wrap the DoLoopOp with an 45 // initalizer and an end point 46 auto *initBlock = rewriter.getInsertionBlock(); 47 auto initPos = rewriter.getInsertionPoint(); 48 auto *endBlock = rewriter.splitBlock(initBlock, initPos); 49 50 // Split the first DoLoopOp block in two parts. 
The part before will be the 51 // conditional block since it already has the induction variable and 52 // loop-carried values as arguments. 53 auto *conditionalBlock = &loop.getRegion().front(); 54 conditionalBlock->addArgument(rewriter.getIndexType(), loc); 55 auto *firstBlock = 56 rewriter.splitBlock(conditionalBlock, conditionalBlock->begin()); 57 auto *lastBlock = &loop.getRegion().back(); 58 59 // Move the blocks from the DoLoopOp between initBlock and endBlock 60 rewriter.inlineRegionBefore(loop.getRegion(), endBlock); 61 62 // Get loop values from the DoLoopOp 63 auto low = loop.getLowerBound(); 64 auto high = loop.getUpperBound(); 65 assert(low && high && "must be a Value"); 66 auto step = loop.getStep(); 67 68 // Initalization block 69 rewriter.setInsertionPointToEnd(initBlock); 70 auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low); 71 auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step); 72 mlir::Value iters = 73 rewriter.create<mlir::arith::DivSIOp>(loc, distance, step); 74 75 if (forceLoopToExecuteOnce) { 76 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 77 auto cond = rewriter.create<mlir::arith::CmpIOp>( 78 loc, arith::CmpIPredicate::sle, iters, zero); 79 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 80 iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters); 81 } 82 83 llvm::SmallVector<mlir::Value> loopOperands; 84 loopOperands.push_back(low); 85 auto operands = loop.getIterOperands(); 86 loopOperands.append(operands.begin(), operands.end()); 87 loopOperands.push_back(iters); 88 89 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands); 90 91 // Last loop block 92 auto *terminator = lastBlock->getTerminator(); 93 rewriter.setInsertionPointToEnd(lastBlock); 94 auto iv = conditionalBlock->getArgument(0); 95 mlir::Value steppedIndex = 96 rewriter.create<mlir::arith::AddIOp>(loc, iv, step); 97 assert(steppedIndex && "must be a Value"); 98 auto lastArg 
= conditionalBlock->getNumArguments() - 1; 99 auto itersLeft = conditionalBlock->getArgument(lastArg); 100 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); 101 mlir::Value itersMinusOne = 102 rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one); 103 104 llvm::SmallVector<mlir::Value> loopCarried; 105 loopCarried.push_back(steppedIndex); 106 auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin()) 107 : terminator->operand_begin(); 108 loopCarried.append(begin, terminator->operand_end()); 109 loopCarried.push_back(itersMinusOne); 110 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried); 111 rewriter.eraseOp(terminator); 112 113 // Conditional block 114 rewriter.setInsertionPointToEnd(conditionalBlock); 115 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); 116 auto comparison = rewriter.create<mlir::arith::CmpIOp>( 117 loc, arith::CmpIPredicate::sgt, itersLeft, zero); 118 119 rewriter.create<mlir::cf::CondBranchOp>( 120 loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock, 121 llvm::ArrayRef<mlir::Value>()); 122 123 // The result of the loop operation is the values of the condition block 124 // arguments except the induction variable on the last iteration. 125 auto args = loop.getFinalValue() 126 ? 
conditionalBlock->getArguments() 127 : conditionalBlock->getArguments().drop_front(); 128 rewriter.replaceOp(loop, args.drop_back()); 129 return success(); 130 } 131 132 private: 133 bool forceLoopToExecuteOnce; 134 }; 135 136 /// Convert `fir.if` to control-flow 137 class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> { 138 public: 139 using OpRewritePattern::OpRewritePattern; 140 141 CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) 142 : mlir::OpRewritePattern<fir::IfOp>(ctx) {} 143 144 mlir::LogicalResult 145 matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override { 146 auto loc = ifOp.getLoc(); 147 148 // Split the block containing the 'fir.if' into two parts. The part before 149 // will contain the condition, the part after will be the continuation 150 // point. 151 auto *condBlock = rewriter.getInsertionBlock(); 152 auto opPosition = rewriter.getInsertionPoint(); 153 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); 154 mlir::Block *continueBlock; 155 if (ifOp.getNumResults() == 0) { 156 continueBlock = remainingOpsBlock; 157 } else { 158 continueBlock = 159 rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes()); 160 rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock); 161 } 162 163 // Move blocks from the "then" region to the region containing 'fir.if', 164 // place it before the continuation block, and branch to it. 
165 auto &ifOpRegion = ifOp.getThenRegion(); 166 auto *ifOpBlock = &ifOpRegion.front(); 167 auto *ifOpTerminator = ifOpRegion.back().getTerminator(); 168 auto ifOpTerminatorOperands = ifOpTerminator->getOperands(); 169 rewriter.setInsertionPointToEnd(&ifOpRegion.back()); 170 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, 171 ifOpTerminatorOperands); 172 rewriter.eraseOp(ifOpTerminator); 173 rewriter.inlineRegionBefore(ifOpRegion, continueBlock); 174 175 // Move blocks from the "else" region (if present) to the region containing 176 // 'fir.if', place it before the continuation block and branch to it. It 177 // will be placed after the "then" regions. 178 auto *otherwiseBlock = continueBlock; 179 auto &otherwiseRegion = ifOp.getElseRegion(); 180 if (!otherwiseRegion.empty()) { 181 otherwiseBlock = &otherwiseRegion.front(); 182 auto *otherwiseTerm = otherwiseRegion.back().getTerminator(); 183 auto otherwiseTermOperands = otherwiseTerm->getOperands(); 184 rewriter.setInsertionPointToEnd(&otherwiseRegion.back()); 185 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, 186 otherwiseTermOperands); 187 rewriter.eraseOp(otherwiseTerm); 188 rewriter.inlineRegionBefore(otherwiseRegion, continueBlock); 189 } 190 191 rewriter.setInsertionPointToEnd(condBlock); 192 rewriter.create<mlir::cf::CondBranchOp>( 193 loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(), 194 otherwiseBlock, llvm::ArrayRef<mlir::Value>()); 195 rewriter.replaceOp(ifOp, continueBlock->getArguments()); 196 return success(); 197 } 198 }; 199 200 /// Convert `fir.iter_while` to control-flow. 
// Lowers `fir.iter_while` to a CFG. A condition block tests both the
// induction-variable bound and the loop-carried "iterate" flag. The body's
// last block steps the induction variable and branches back to it.
class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // `forceLoopToExecuteOnce` is unused here. It keeps the constructor
  // signature identical to CfgLoopConv so that all patterns can be registered
  // with the same arguments.
  CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
      : mlir::OpRewritePattern<fir::IterWhileOp>(ctx) {}

  mlir::LogicalResult
  matchAndRewrite(fir::IterWhileOp whileOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = whileOp.getLoc();

    // Start by splitting the block containing the 'fir.iter_while' into two
    // parts. The part before will get the init code, the part after will be
    // the end point.
    // NOTE(review): relies on the rewriter's insertion point being at the
    // matched op when matchAndRewrite is entered.
    auto *initBlock = rewriter.getInsertionBlock();
    auto initPosition = rewriter.getInsertionPoint();
    auto *endBlock = rewriter.splitBlock(initBlock, initPosition);

    // Use the first block of the loop body as the condition block since it is
    // the block that has the induction variable and loop-carried values as
    // arguments. Split out all operations from the first block into a new
    // block. Move all body blocks from the loop body region to the region
    // containing the loop.
    auto *conditionBlock = &whileOp.getRegion().front();
    auto *firstBodyBlock =
        rewriter.splitBlock(conditionBlock, conditionBlock->begin());
    auto *lastBodyBlock = &whileOp.getRegion().back();
    rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock);
    // Argument 0 is the induction variable, argument 1 the early-exit
    // ("iterate") flag; the remaining arguments are loop-carried values.
    auto iv = conditionBlock->getArgument(0);
    auto iterateVar = conditionBlock->getArgument(1);

    // Append the induction variable stepping logic to the last body block and
    // branch back to the condition block. Loop-carried values are taken from
    // operands of the loop terminator.
    auto *terminator = lastBodyBlock->getTerminator();
    rewriter.setInsertionPointToEnd(lastBodyBlock);
    auto step = whileOp.getStep();
    mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
    assert(stepped && "must be a Value");

    llvm::SmallVector<mlir::Value> loopCarried;
    loopCarried.push_back(stepped);
    // With a final value, the terminator's first operand is not forwarded;
    // the stepped induction variable takes its place.
    auto begin = whileOp.getFinalValue()
                     ? std::next(terminator->operand_begin())
                     : terminator->operand_begin();
    loopCarried.append(begin, terminator->operand_end());
    rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
    rewriter.eraseOp(terminator);

    // Compute loop bounds before branching to the condition.
    rewriter.setInsertionPointToEnd(initBlock);
    auto lowerBound = whileOp.getLowerBound();
    auto upperBound = whileOp.getUpperBound();
    assert(lowerBound && upperBound && "must be a Value");

    // The initial values of the loop-carried values are obtained from the
    // operands of the loop operation.
    llvm::SmallVector<mlir::Value> destOperands;
    destOperands.push_back(lowerBound);
    auto iterOperands = whileOp.getIterOperands();
    destOperands.append(iterOperands.begin(), iterOperands.end());
    rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);

    // With the body block done, we can fill in the condition block.
    rewriter.setInsertionPointToEnd(conditionBlock);
    // The comparison depends on the sign of the step value. We fully expect
    // this expression to be folded by the optimizer or LLVM. This expression
    // is written this way so that `step == 0` always returns `false`:
    //   (0 < step && iv <= ub) || (step < 0 && ub <= iv)
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, zero, step);
    auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, iv, upperBound);
    auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, step, zero);
    auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, upperBound, iv);
    auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
    auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
    auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
    // Remember to AND in the early-exit bool.
    auto comparison =
        rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
    rewriter.create<mlir::cf::CondBranchOp>(
        loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
        endBlock, llvm::ArrayRef<mlir::Value>());
    // The result of the loop operation is the values of the condition block
    // arguments except the induction variable on the last iteration.
    auto args = whileOp.getFinalValue()
                    ? conditionBlock->getArguments()
                    : conditionBlock->getArguments().drop_front();
    rewriter.replaceOp(whileOp, args);
    return success();
  }
};

/// Convert FIR structured control flow ops to CFG ops.
299 class CfgConversion : public CFGConversionBase<CfgConversion> { 300 public: 301 void runOnOperation() override { 302 auto *context = &getContext(); 303 mlir::RewritePatternSet patterns(context); 304 patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( 305 context, forceLoopToExecuteOnce); 306 mlir::ConversionTarget target(*context); 307 target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect, 308 FIROpsDialect, mlir::func::FuncDialect>(); 309 310 // apply the patterns 311 target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); 312 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 313 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, 314 std::move(patterns)))) { 315 mlir::emitError(mlir::UnknownLoc::get(context), 316 "error in converting to CFG\n"); 317 signalPassFailure(); 318 } 319 } 320 }; 321 } // namespace 322 323 /// Convert FIR's structured control flow ops to CFG ops. This 324 /// conversion enables the `createLowerToCFGPass` to transform these to CFG 325 /// form. 326 std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() { 327 return std::make_unique<CfgConversion>(); 328 } 329