1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Transforms SCF.ForOp's into SCF.WhileOp's. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/SCF/Passes.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace llvm; 23 using namespace mlir; 24 using scf::ForOp; 25 using scf::WhileOp; 26 27 namespace { 28 29 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { 30 using OpRewritePattern<ForOp>::OpRewritePattern; 31 32 LogicalResult matchAndRewrite(ForOp forOp, 33 PatternRewriter &rewriter) const override { 34 // Generate type signature for the loop-carried values. The induction 35 // variable is placed first, followed by the forOp.iterArgs. 36 SmallVector<Type> lcvTypes; 37 SmallVector<Location> lcvLocs; 38 lcvTypes.push_back(forOp.getInductionVar().getType()); 39 lcvLocs.push_back(forOp.getInductionVar().getLoc()); 40 for (Value value : forOp.getInitArgs()) { 41 lcvTypes.push_back(value.getType()); 42 lcvLocs.push_back(value.getLoc()); 43 } 44 45 // Build scf.WhileOp 46 SmallVector<Value> initArgs; 47 initArgs.push_back(forOp.getLowerBound()); 48 llvm::append_range(initArgs, forOp.getInitArgs()); 49 auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs, 50 forOp->getAttrs()); 51 52 // 'before' region contains the loop condition and forwarding of iteration 53 // arguments to the 'after' region. 54 auto *beforeBlock = rewriter.createBlock( 55 &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); 56 rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); 57 auto cmpOp = rewriter.create<arith::CmpIOp>( 58 whileOp.getLoc(), arith::CmpIPredicate::slt, 59 beforeBlock->getArgument(0), forOp.getUpperBound()); 60 rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(), 61 beforeBlock->getArguments()); 62 63 // Inline for-loop body into an executeRegion operation in the "after" 64 // region. The return type of the execRegionOp does not contain the 65 // iv - yields in the source for-loop contain only iterArgs. 66 auto *afterBlock = rewriter.createBlock( 67 &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs); 68 69 // Add induction variable incrementation 70 rewriter.setInsertionPointToEnd(afterBlock); 71 auto ivIncOp = rewriter.create<arith::AddIOp>( 72 whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep()); 73 74 // Rewrite uses of the for-loop block arguments to the new while-loop 75 // "after" arguments 76 for (const auto &barg : enumerate(forOp.getBody(0)->getArguments())) 77 barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index())); 78 79 // Inline for-loop body operations into 'after' region. 80 for (auto &arg : llvm::make_early_inc_range(*forOp.getBody())) 81 arg.moveBefore(afterBlock, afterBlock->end()); 82 83 // Add incremented IV to yield operations 84 for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) { 85 SmallVector<Value> yieldOperands = yieldOp.getOperands(); 86 yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult()); 87 yieldOp->setOperands(yieldOperands); 88 } 89 90 // We cannot do a direct replacement of the forOp since the while op returns 91 // an extra value (the induction variable escapes the loop through being 92 // carried in the set of iterargs). Instead, rewrite uses of the forOp 93 // results. 94 for (const auto &arg : llvm::enumerate(forOp.getResults())) 95 arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1)); 96 97 rewriter.eraseOp(forOp); 98 return success(); 99 } 100 }; 101 102 struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> { 103 void runOnOperation() override { 104 FuncOp funcOp = getOperation(); 105 MLIRContext *ctx = funcOp.getContext(); 106 RewritePatternSet patterns(ctx); 107 patterns.add<ForLoopLoweringPattern>(ctx); 108 (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 109 } 110 }; 111 } // namespace 112 113 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() { 114 return std::make_unique<ForToWhileLoop>(); 115 } 116