1 //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to unroll loops by a specified unroll factor. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/SCF/SCF.h" 15 #include "mlir/Dialect/SCF/Utils/Utils.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/Pass/Pass.h" 19 20 using namespace mlir; 21 22 namespace { 23 24 static unsigned getNestingDepth(Operation *op) { 25 Operation *currOp = op; 26 unsigned depth = 0; 27 while ((currOp = currOp->getParentOp())) { 28 if (isa<scf::ForOp>(currOp)) 29 depth++; 30 } 31 return depth; 32 } 33 34 class TestLoopUnrollingPass 35 : public PassWrapper<TestLoopUnrollingPass, OperationPass<FuncOp>> { 36 public: 37 StringRef getArgument() const final { return "test-loop-unrolling"; } 38 StringRef getDescription() const final { 39 return "Tests loop unrolling transformation"; 40 } 41 TestLoopUnrollingPass() = default; 42 TestLoopUnrollingPass(const TestLoopUnrollingPass &) {} 43 explicit TestLoopUnrollingPass(uint64_t unrollFactorParam, 44 unsigned loopDepthParam, 45 bool annotateLoopParam) { 46 unrollFactor = unrollFactorParam; 47 loopDepth = loopDepthParam; 48 annotateLoop = annotateLoopParam; 49 } 50 51 void getDependentDialects(DialectRegistry ®istry) const override { 52 registry.insert<arith::ArithmeticDialect, StandardOpsDialect>(); 53 } 54 55 void runOnOperation() override { 56 FuncOp func = getOperation(); 57 SmallVector<scf::ForOp, 4> loops; 58 func.walk([&](scf::ForOp forOp) { 59 if (getNestingDepth(forOp) == loopDepth) 60 loops.push_back(forOp); 61 }); 62 auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) { 63 if (annotateLoop) { 64 op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i)); 65 } 66 }; 67 for (auto loop : loops) 68 (void)loopUnrollByFactor(loop, unrollFactor, annotateFn); 69 } 70 Option<uint64_t> unrollFactor{*this, "unroll-factor", 71 llvm::cl::desc("Loop unroll factor."), 72 llvm::cl::init(1)}; 73 Option<bool> annotateLoop{*this, "annotate", 74 llvm::cl::desc("Annotate unrolled iterations."), 75 llvm::cl::init(false)}; 76 Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor", 77 llvm::cl::desc("Loop unroll up to factor."), 78 llvm::cl::init(false)}; 79 Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."), 80 llvm::cl::init(0)}; 81 }; 82 } // namespace 83 84 namespace mlir { 85 namespace test { 86 void registerTestLoopUnrollingPass() { 87 PassRegistration<TestLoopUnrollingPass>(); 88 } 89 } // namespace test 90 } // namespace mlir 91