1032cb165SMorten Borup Petersen //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===// 2032cb165SMorten Borup Petersen // 3032cb165SMorten Borup Petersen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4032cb165SMorten Borup Petersen // See https://llvm.org/LICENSE.txt for license information. 5032cb165SMorten Borup Petersen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6032cb165SMorten Borup Petersen // 7032cb165SMorten Borup Petersen //===----------------------------------------------------------------------===// 8032cb165SMorten Borup Petersen // 9032cb165SMorten Borup Petersen // Transforms SCF.ForOp's into SCF.WhileOp's. 10032cb165SMorten Borup Petersen // 11032cb165SMorten Borup Petersen //===----------------------------------------------------------------------===// 12032cb165SMorten Borup Petersen 13032cb165SMorten Borup Petersen #include "PassDetail.h" 14*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15032cb165SMorten Borup Petersen #include "mlir/Dialect/SCF/Passes.h" 16032cb165SMorten Borup Petersen #include "mlir/Dialect/SCF/SCF.h" 17032cb165SMorten Borup Petersen #include "mlir/Dialect/SCF/Transforms.h" 18032cb165SMorten Borup Petersen #include "mlir/Dialect/StandardOps/IR/Ops.h" 19032cb165SMorten Borup Petersen #include "mlir/IR/PatternMatch.h" 20032cb165SMorten Borup Petersen #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21032cb165SMorten Borup Petersen 22032cb165SMorten Borup Petersen using namespace llvm; 23032cb165SMorten Borup Petersen using namespace mlir; 24032cb165SMorten Borup Petersen using scf::ForOp; 25032cb165SMorten Borup Petersen using scf::WhileOp; 26032cb165SMorten Borup Petersen 27032cb165SMorten Borup Petersen namespace { 28032cb165SMorten Borup Petersen 29032cb165SMorten Borup Petersen struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { 30032cb165SMorten Borup Petersen using OpRewritePattern<ForOp>::OpRewritePattern; 31032cb165SMorten Borup Petersen 32032cb165SMorten Borup Petersen LogicalResult matchAndRewrite(ForOp forOp, 33032cb165SMorten Borup Petersen PatternRewriter &rewriter) const override { 34032cb165SMorten Borup Petersen // Generate type signature for the loop-carried values. The induction 35032cb165SMorten Borup Petersen // variable is placed first, followed by the forOp.iterArgs. 36032cb165SMorten Borup Petersen SmallVector<Type, 8> lcvTypes; 37032cb165SMorten Borup Petersen lcvTypes.push_back(forOp.getInductionVar().getType()); 38032cb165SMorten Borup Petersen llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes), 39032cb165SMorten Borup Petersen [&](auto v) { return v.getType(); }); 40032cb165SMorten Borup Petersen 41032cb165SMorten Borup Petersen // Build scf.WhileOp 42032cb165SMorten Borup Petersen SmallVector<Value> initArgs; 43032cb165SMorten Borup Petersen initArgs.push_back(forOp.lowerBound()); 44032cb165SMorten Borup Petersen llvm::append_range(initArgs, forOp.initArgs()); 45032cb165SMorten Borup Petersen auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs, 46032cb165SMorten Borup Petersen forOp->getAttrs()); 47032cb165SMorten Borup Petersen 48032cb165SMorten Borup Petersen // 'before' region contains the loop condition and forwarding of iteration 49032cb165SMorten Borup Petersen // arguments to the 'after' region. 50032cb165SMorten Borup Petersen auto *beforeBlock = rewriter.createBlock( 51032cb165SMorten Borup Petersen &whileOp.before(), whileOp.before().begin(), lcvTypes, {}); 52032cb165SMorten Borup Petersen rewriter.setInsertionPointToStart(&whileOp.before().front()); 53*a54f4eaeSMogball auto cmpOp = rewriter.create<arith::CmpIOp>( 54*a54f4eaeSMogball whileOp.getLoc(), arith::CmpIPredicate::slt, 55*a54f4eaeSMogball beforeBlock->getArgument(0), forOp.upperBound()); 56032cb165SMorten Borup Petersen rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(), 57032cb165SMorten Borup Petersen beforeBlock->getArguments()); 58032cb165SMorten Borup Petersen 59032cb165SMorten Borup Petersen // Inline for-loop body into an executeRegion operation in the "after" 60032cb165SMorten Borup Petersen // region. The return type of the execRegionOp does not contain the 61032cb165SMorten Borup Petersen // iv - yields in the source for-loop contain only iterArgs. 62032cb165SMorten Borup Petersen auto *afterBlock = rewriter.createBlock( 63032cb165SMorten Borup Petersen &whileOp.after(), whileOp.after().begin(), lcvTypes, {}); 64032cb165SMorten Borup Petersen 65032cb165SMorten Borup Petersen // Add induction variable incrementation 66032cb165SMorten Borup Petersen rewriter.setInsertionPointToEnd(afterBlock); 67*a54f4eaeSMogball auto ivIncOp = rewriter.create<arith::AddIOp>( 68032cb165SMorten Borup Petersen whileOp.getLoc(), afterBlock->getArgument(0), forOp.step()); 69032cb165SMorten Borup Petersen 70032cb165SMorten Borup Petersen // Rewrite uses of the for-loop block arguments to the new while-loop 71032cb165SMorten Borup Petersen // "after" arguments 72032cb165SMorten Borup Petersen for (auto barg : enumerate(forOp.getBody(0)->getArguments())) 73032cb165SMorten Borup Petersen barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index())); 74032cb165SMorten Borup Petersen 75032cb165SMorten Borup Petersen // Inline for-loop body operations into 'after' region. 76032cb165SMorten Borup Petersen for (auto &arg : llvm::make_early_inc_range(*forOp.getBody())) 77032cb165SMorten Borup Petersen arg.moveBefore(afterBlock, afterBlock->end()); 78032cb165SMorten Borup Petersen 79032cb165SMorten Borup Petersen // Add incremented IV to yield operations 80032cb165SMorten Borup Petersen for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) { 81032cb165SMorten Borup Petersen SmallVector<Value> yieldOperands = yieldOp.getOperands(); 82032cb165SMorten Borup Petersen yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult()); 83032cb165SMorten Borup Petersen yieldOp->setOperands(yieldOperands); 84032cb165SMorten Borup Petersen } 85032cb165SMorten Borup Petersen 86032cb165SMorten Borup Petersen // We cannot do a direct replacement of the forOp since the while op returns 87032cb165SMorten Borup Petersen // an extra value (the induction variable escapes the loop through being 88032cb165SMorten Borup Petersen // carried in the set of iterargs). Instead, rewrite uses of the forOp 89032cb165SMorten Borup Petersen // results. 90032cb165SMorten Borup Petersen for (auto arg : llvm::enumerate(forOp.getResults())) 91032cb165SMorten Borup Petersen arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1)); 92032cb165SMorten Borup Petersen 93032cb165SMorten Borup Petersen rewriter.eraseOp(forOp); 94032cb165SMorten Borup Petersen return success(); 95032cb165SMorten Borup Petersen } 96032cb165SMorten Borup Petersen }; 97032cb165SMorten Borup Petersen 98032cb165SMorten Borup Petersen struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> { 99032cb165SMorten Borup Petersen void runOnFunction() override { 100032cb165SMorten Borup Petersen FuncOp funcOp = getFunction(); 101032cb165SMorten Borup Petersen MLIRContext *ctx = funcOp.getContext(); 102032cb165SMorten Borup Petersen RewritePatternSet patterns(ctx); 103032cb165SMorten Borup Petersen patterns.add<ForLoopLoweringPattern>(ctx); 104032cb165SMorten Borup Petersen (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); 105032cb165SMorten Borup Petersen } 106032cb165SMorten Borup Petersen }; 107032cb165SMorten Borup Petersen } // namespace 108032cb165SMorten Borup Petersen 109032cb165SMorten Borup Petersen std::unique_ptr<Pass> mlir::createForToWhileLoopPass() { 110032cb165SMorten Borup Petersen return std::make_unique<ForToWhileLoop>(); 111032cb165SMorten Borup Petersen } 112