//===-- RewriteLoop.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/CommandLine.h"

using namespace fir;

namespace {

// Conversion of FIR structured control-flow ops to more primitive
// control-flow (blocks connected by branch operations).
//
// FIR loops that cannot be converted to the affine dialect will remain as
// `fir.do_loop` operations. These can be converted to control-flow operations.

/// Convert `fir.do_loop` to CFG.
///
/// The loop is rewritten into an explicit trip-count form:
///   - the entry (init) block computes `iters = (upper - lower + step) / step`
///     with signed division and branches to the conditional block;
///   - the conditional block receives the induction variable, the
///     loop-carried values and the remaining trip count as block arguments,
///     and enters the body while the remaining count is greater than zero;
///   - the last body block adds `step` to the induction variable, subtracts 1
///     from the remaining count, and branches back to the conditional block.
///
/// When `forceLoopToExecuteOnce` is set, a non-positive trip count is replaced
/// by 1 so that the body runs at least once.
class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// \param forceLoopToExecuteOnce if true, clamp a non-positive computed trip
  ///        count to 1.
  CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
      : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
        forceLoopToExecuteOnce(forceLoopToExecuteOnce) {}

  mlir::LogicalResult
  matchAndRewrite(DoLoopOp loop,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = loop.getLoc();

    // Create the start and end blocks that will wrap the DoLoopOp with an
    // initializer and an end point. The current block is split at the
    // rewriter's insertion point (presumably positioned at `loop` by the
    // conversion driver); the trailing part becomes the loop exit `endBlock`.
    auto *initBlock = rewriter.getInsertionBlock();
    auto initPos = rewriter.getInsertionPoint();
    auto *endBlock = rewriter.splitBlock(initBlock, initPos);

    // Split the first DoLoopOp block in two parts. The part before will be the
    // conditional block since it already has the induction variable and
    // loop-carried values as arguments. An extra trailing index argument is
    // appended to it to carry the remaining trip count.
    auto *conditionalBlock = &loop.region().front();
    conditionalBlock->addArgument(rewriter.getIndexType());
    auto *firstBlock =
        rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
    auto *lastBlock = &loop.region().back();

    // Move the blocks from the DoLoopOp between initBlock and endBlock
    rewriter.inlineRegionBefore(loop.region(), endBlock);

    // Get loop values from the DoLoopOp
    auto low = loop.lowerBound();
    auto high = loop.upperBound();
    assert(low && high && "must be a Value");
    auto step = loop.step();

    // Initialization block: trip count = (high - low + step) / step.
    rewriter.setInsertionPointToEnd(initBlock);
    auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
    auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
    mlir::Value iters =
        rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);

    // Optionally clamp a non-positive trip count to 1 (iters <= 0 ? 1 : iters).
    if (forceLoopToExecuteOnce) {
      auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
      auto cond = rewriter.create<mlir::arith::CmpIOp>(
          loc, arith::CmpIPredicate::sle, iters, zero);
      auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
      iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters);
    }

    // Initial conditional-block arguments: (lower bound, iter operands...,
    // trip count).
    llvm::SmallVector<mlir::Value> loopOperands;
    loopOperands.push_back(low);
    auto operands = loop.getIterOperands();
    loopOperands.append(operands.begin(), operands.end());
    loopOperands.push_back(iters);

    rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopOperands);

    // Last loop block: step the induction variable, decrement the remaining
    // trip count, and branch back to the conditional block.
    auto *terminator = lastBlock->getTerminator();
    rewriter.setInsertionPointToEnd(lastBlock);
    auto iv = conditionalBlock->getArgument(0);
    mlir::Value steppedIndex =
        rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
    assert(steppedIndex && "must be a Value");
    auto lastArg = conditionalBlock->getNumArguments() - 1;
    auto itersLeft = conditionalBlock->getArgument(lastArg);
    auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
    mlir::Value itersMinusOne =
        rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one);

    // Back-edge arguments: (stepped iv, terminator operands..., count - 1).
    // When the loop also yields its final value, the first terminator operand
    // is skipped; presumably it is the induction-variable result, which is
    // supplied here by `steppedIndex` instead (TODO confirm against DoLoopOp).
    llvm::SmallVector<mlir::Value> loopCarried;
    loopCarried.push_back(steppedIndex);
    auto begin = loop.finalValue() ? std::next(terminator->operand_begin())
                                   : terminator->operand_begin();
    loopCarried.append(begin, terminator->operand_end());
    loopCarried.push_back(itersMinusOne);
    rewriter.create<mlir::BranchOp>(loc, conditionalBlock, loopCarried);
    rewriter.eraseOp(terminator);

    // Conditional block: enter the body while the remaining trip count is
    // strictly positive, otherwise exit to endBlock.
    rewriter.setInsertionPointToEnd(conditionalBlock);
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto comparison = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, itersLeft, zero);

    rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBlock,
                                        llvm::ArrayRef<mlir::Value>(), endBlock,
                                        llvm::ArrayRef<mlir::Value>());

    // The results of the loop operation are the conditional block arguments
    // as of the last evaluation of the exit test, minus the trailing trip
    // count. The leading induction variable is included only when the loop
    // returns its final value (`finalValue()`).
    auto args = loop.finalValue()
                    ? conditionalBlock->getArguments()
                    : conditionalBlock->getArguments().drop_front();
    rewriter.replaceOp(loop, args.drop_back());
    return success();
  }

private:
  /// Clamp a non-positive computed trip count to 1 when set.
  bool forceLoopToExecuteOnce;
};

/// Convert `fir.if` to control-flow.
///
/// The block holding the op is split at the op. If the op has results, a new
/// continuation block taking the result types as arguments is created in
/// front of the remaining operations and branches to them. The "then" region
/// (and the "else" region, if non-empty) is inlined before the continuation
/// block, with its terminator replaced by a branch that forwards the
/// terminator operands. A conditional branch on the `fir.if` condition then
/// selects between the two; without an "else" region the false edge goes
/// straight to the continuation block.
class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // `forceLoopToExecuteOnce` is unused here; the parameter lets all three
  // patterns be registered with one `patterns.insert<...>(context, flag)` call.
  CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
      : mlir::OpRewritePattern<fir::IfOp>(ctx) {}

  mlir::LogicalResult
  matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
    auto loc = ifOp.getLoc();

    // Split the block containing the 'fir.if' into two parts. The part before
    // will contain the condition, the part after will be the continuation
    // point.
    auto *condBlock = rewriter.getInsertionBlock();
    auto opPosition = rewriter.getInsertionPoint();
    auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
    mlir::Block *continueBlock;
    if (ifOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      // A dedicated block receives the `fir.if` results as block arguments
      // and falls through to the remaining operations.
      continueBlock =
          rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
      rewriter.create<mlir::BranchOp>(loc, remainingOpsBlock);
    }

    // Move blocks from the "then" region to the region containing 'fir.if',
    // place them before the continuation block, and branch to it. The region
    // terminator is replaced by a branch forwarding its operands.
    auto &ifOpRegion = ifOp.thenRegion();
    auto *ifOpBlock = &ifOpRegion.front();
    auto *ifOpTerminator = ifOpRegion.back().getTerminator();
    auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
    rewriter.setInsertionPointToEnd(&ifOpRegion.back());
    rewriter.create<mlir::BranchOp>(loc, continueBlock, ifOpTerminatorOperands);
    rewriter.eraseOp(ifOpTerminator);
    rewriter.inlineRegionBefore(ifOpRegion, continueBlock);

    // Move blocks from the "else" region (if present) to the region containing
    // 'fir.if', place them before the continuation block and branch to it.
    // They will be placed after the "then" blocks. With no "else" region, the
    // false edge of the condition targets the continuation block directly.
    auto *otherwiseBlock = continueBlock;
    auto &otherwiseRegion = ifOp.elseRegion();
    if (!otherwiseRegion.empty()) {
      otherwiseBlock = &otherwiseRegion.front();
      auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
      auto otherwiseTermOperands = otherwiseTerm->getOperands();
      rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
      rewriter.create<mlir::BranchOp>(loc, continueBlock,
                                      otherwiseTermOperands);
      rewriter.eraseOp(otherwiseTerm);
      rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
    }

    // Close the condition block with the two-way branch, then replace the op's
    // results with the continuation block's arguments.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<mlir::CondBranchOp>(
        loc, ifOp.condition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
        otherwiseBlock, llvm::ArrayRef<mlir::Value>());
    rewriter.replaceOp(ifOp, continueBlock->getArguments());
    return success();
  }
};

/// Convert `fir.iter_while` to control-flow.
///
/// The first region block becomes the condition block. Its arguments are the
/// induction variable, the early-exit `iterate` flag and any further
/// loop-carried values. Its operations are moved into a fresh body block. The
/// last body block adds `step` to the induction variable and branches back with
/// the terminator's operands. The condition block keeps iterating while the
/// flag is true and the induction variable has not passed the upper bound in
/// the direction of the step.
class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // `forceLoopToExecuteOnce` is unused here; the parameter lets all three
  // patterns be registered with one `patterns.insert<...>(context, flag)` call.
  CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce)
      : mlir::OpRewritePattern<fir::IterWhileOp>(ctx) {}

  mlir::LogicalResult
  matchAndRewrite(fir::IterWhileOp whileOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = whileOp.getLoc();

    // Start by splitting the block containing the 'fir.iter_while' into two
    // parts. The part before will get the init code, the part after will be
    // the end point.
    auto *initBlock = rewriter.getInsertionBlock();
    auto initPosition = rewriter.getInsertionPoint();
    auto *endBlock = rewriter.splitBlock(initBlock, initPosition);

    // Use the first block of the loop body as the condition block since it is
    // the block that has the induction variable and loop-carried values as
    // arguments. Split out all operations from the first block into a new
    // block. Move all body blocks from the loop body region to the region
    // containing the loop.
    auto *conditionBlock = &whileOp.region().front();
    auto *firstBodyBlock =
        rewriter.splitBlock(conditionBlock, conditionBlock->begin());
    auto *lastBodyBlock = &whileOp.region().back();
    rewriter.inlineRegionBefore(whileOp.region(), endBlock);
    // Argument 0 is the induction variable; argument 1 is the early-exit flag
    // that is ANDed into the loop condition below.
    auto iv = conditionBlock->getArgument(0);
    auto iterateVar = conditionBlock->getArgument(1);

    // Append the induction variable stepping logic to the last body block and
    // branch back to the condition block. Loop-carried values are taken from
    // operands of the loop terminator.
    auto *terminator = lastBodyBlock->getTerminator();
    rewriter.setInsertionPointToEnd(lastBodyBlock);
    auto step = whileOp.step();
    mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step);
    assert(stepped && "must be a Value");

    // Back-edge arguments: (stepped iv, terminator operands...). When the loop
    // also yields its final value, the first terminator operand is skipped;
    // presumably it is the induction-variable result, which is supplied here
    // by `stepped` instead (TODO confirm against IterWhileOp).
    llvm::SmallVector<mlir::Value> loopCarried;
    loopCarried.push_back(stepped);
    auto begin = whileOp.finalValue() ? std::next(terminator->operand_begin())
                                      : terminator->operand_begin();
    loopCarried.append(begin, terminator->operand_end());
    rewriter.create<mlir::BranchOp>(loc, conditionBlock, loopCarried);
    rewriter.eraseOp(terminator);

    // Fetch the loop bounds and branch from the init block to the condition
    // block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto lowerBound = whileOp.lowerBound();
    auto upperBound = whileOp.upperBound();
    assert(lowerBound && upperBound && "must be a Value");

    // The initial values of the loop-carried values are obtained from the
    // operands of the loop operation; the induction variable starts at the
    // lower bound.
    llvm::SmallVector<mlir::Value> destOperands;
    destOperands.push_back(lowerBound);
    auto iterOperands = whileOp.getIterOperands();
    destOperands.append(iterOperands.begin(), iterOperands.end());
    rewriter.create<mlir::BranchOp>(loc, conditionBlock, destOperands);

    // With the body block done, we can fill in the condition block.
    rewriter.setInsertionPointToEnd(conditionBlock);
    // The comparison depends on the sign of the step value. We fully expect
    // this expression to be folded by the optimizer or LLVM. This expression
    // is written this way so that `step == 0` always returns `false`:
    //   (0 < step && iv <= ub) || (step < 0 && ub <= iv)
    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
    auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, zero, step);
    auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, iv, upperBound);
    auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, step, zero);
    auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, upperBound, iv);
    auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
    auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
    auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
    // Remember to AND in the early-exit bool.
    auto comparison =
        rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
    rewriter.create<mlir::CondBranchOp>(loc, comparison, firstBodyBlock,
                                        llvm::ArrayRef<mlir::Value>(), endBlock,
                                        llvm::ArrayRef<mlir::Value>());
    // The results of the loop operation are the condition block arguments as
    // of the last evaluation of the exit test. The leading induction variable
    // is included only when the loop returns its final value (`finalValue()`).
    auto args = whileOp.finalValue()
                    ? conditionBlock->getArguments()
                    : conditionBlock->getArguments().drop_front();
    rewriter.replaceOp(whileOp, args);
    return success();
  }
};

/// Convert FIR structured control flow ops to CFG ops.
295 class CfgConversion : public CFGConversionBase<CfgConversion> { 296 public: 297 void runOnFunction() override { 298 auto *context = &getContext(); 299 mlir::OwningRewritePatternList patterns(context); 300 patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( 301 context, forceLoopToExecuteOnce); 302 mlir::ConversionTarget target(*context); 303 target.addLegalDialect<mlir::AffineDialect, FIROpsDialect, 304 mlir::StandardOpsDialect>(); 305 306 // apply the patterns 307 target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); 308 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 309 if (mlir::failed(mlir::applyPartialConversion(getFunction(), target, 310 std::move(patterns)))) { 311 mlir::emitError(mlir::UnknownLoc::get(context), 312 "error in converting to CFG\n"); 313 signalPassFailure(); 314 } 315 } 316 }; 317 } // namespace 318 319 /// Convert FIR's structured control flow ops to CFG ops. This 320 /// conversion enables the `createLowerToCFGPass` to transform these to CFG 321 /// form. 322 std::unique_ptr<mlir::Pass> fir::createFirToCfgPass() { 323 return std::make_unique<CfgConversion>(); 324 } 325