1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Transforms SCF.ForOp's into SCF.WhileOp's. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/SCF/Passes.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/SCF/Transforms.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace llvm; 22 using namespace mlir; 23 using scf::ForOp; 24 using scf::WhileOp; 25 26 namespace { 27 28 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { 29 using OpRewritePattern<ForOp>::OpRewritePattern; 30 31 LogicalResult matchAndRewrite(ForOp forOp, 32 PatternRewriter &rewriter) const override { 33 // Generate type signature for the loop-carried values. The induction 34 // variable is placed first, followed by the forOp.iterArgs. 35 SmallVector<Type, 8> lcvTypes; 36 lcvTypes.push_back(forOp.getInductionVar().getType()); 37 llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes), 38 [&](auto v) { return v.getType(); }); 39 40 // Build scf.WhileOp 41 SmallVector<Value> initArgs; 42 initArgs.push_back(forOp.lowerBound()); 43 llvm::append_range(initArgs, forOp.initArgs()); 44 auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs, 45 forOp->getAttrs()); 46 47 // 'before' region contains the loop condition and forwarding of iteration 48 // arguments to the 'after' region. 49 auto *beforeBlock = rewriter.createBlock( 50 &whileOp.before(), whileOp.before().begin(), lcvTypes, {}); 51 rewriter.setInsertionPointToStart(&whileOp.before().front()); 52 auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt, 53 beforeBlock->getArgument(0), 54 forOp.upperBound()); 55 rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(), 56 beforeBlock->getArguments()); 57 58 // Inline for-loop body into an executeRegion operation in the "after" 59 // region. The return type of the execRegionOp does not contain the 60 // iv - yields in the source for-loop contain only iterArgs. 61 auto *afterBlock = rewriter.createBlock( 62 &whileOp.after(), whileOp.after().begin(), lcvTypes, {}); 63 64 // Add induction variable incrementation 65 rewriter.setInsertionPointToEnd(afterBlock); 66 auto ivIncOp = rewriter.create<AddIOp>( 67 whileOp.getLoc(), afterBlock->getArgument(0), forOp.step()); 68 69 // Rewrite uses of the for-loop block arguments to the new while-loop 70 // "after" arguments 71 for (auto barg : enumerate(forOp.getBody(0)->getArguments())) 72 barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index())); 73 74 // Inline for-loop body operations into 'after' region. 75 for (auto &arg : llvm::make_early_inc_range(*forOp.getBody())) 76 arg.moveBefore(afterBlock, afterBlock->end()); 77 78 // Add incremented IV to yield operations 79 for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) { 80 SmallVector<Value> yieldOperands = yieldOp.getOperands(); 81 yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult()); 82 yieldOp->setOperands(yieldOperands); 83 } 84 85 // We cannot do a direct replacement of the forOp since the while op returns 86 // an extra value (the induction variable escapes the loop through being 87 // carried in the set of iterargs). Instead, rewrite uses of the forOp 88 // results. 89 for (auto arg : llvm::enumerate(forOp.getResults())) 90 arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1)); 91 92 rewriter.eraseOp(forOp); 93 return success(); 94 } 95 }; 96 97 struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> { 98 void runOnFunction() override { 99 FuncOp funcOp = getFunction(); 100 MLIRContext *ctx = funcOp.getContext(); 101 RewritePatternSet patterns(ctx); 102 patterns.add<ForLoopLoweringPattern>(ctx); 103 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 104 } 105 }; 106 } // namespace 107 108 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() { 109 return std::make_unique<ForToWhileLoop>(); 110 } 111