1 //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to unroll loops by a specified unroll factor. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/SCF/SCF.h" 15 #include "mlir/Dialect/SCF/Utils/Utils.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/Pass/Pass.h" 18 19 using namespace mlir; 20 21 namespace { 22 23 static unsigned getNestingDepth(Operation *op) { 24 Operation *currOp = op; 25 unsigned depth = 0; 26 while ((currOp = currOp->getParentOp())) { 27 if (isa<scf::ForOp>(currOp)) 28 depth++; 29 } 30 return depth; 31 } 32 33 class TestLoopUnrollingPass 34 : public PassWrapper<TestLoopUnrollingPass, OperationPass<>> { 35 public: 36 StringRef getArgument() const final { return "test-loop-unrolling"; } 37 StringRef getDescription() const final { 38 return "Tests loop unrolling transformation"; 39 } 40 TestLoopUnrollingPass() = default; 41 TestLoopUnrollingPass(const TestLoopUnrollingPass &) {} 42 explicit TestLoopUnrollingPass(uint64_t unrollFactorParam, 43 unsigned loopDepthParam, 44 bool annotateLoopParam) { 45 unrollFactor = unrollFactorParam; 46 loopDepth = loopDepthParam; 47 annotateLoop = annotateLoopParam; 48 } 49 50 void getDependentDialects(DialectRegistry ®istry) const override { 51 registry.insert<arith::ArithmeticDialect>(); 52 } 53 54 void runOnOperation() override { 55 SmallVector<scf::ForOp, 4> loops; 56 getOperation()->walk([&](scf::ForOp forOp) { 57 if (getNestingDepth(forOp) == loopDepth) 58 loops.push_back(forOp); 59 }); 60 auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) { 61 if (annotateLoop) { 62 op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i)); 63 } 64 }; 65 for (auto loop : loops) 66 (void)loopUnrollByFactor(loop, unrollFactor, annotateFn); 67 } 68 Option<uint64_t> unrollFactor{*this, "unroll-factor", 69 llvm::cl::desc("Loop unroll factor."), 70 llvm::cl::init(1)}; 71 Option<bool> annotateLoop{*this, "annotate", 72 llvm::cl::desc("Annotate unrolled iterations."), 73 llvm::cl::init(false)}; 74 Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor", 75 llvm::cl::desc("Loop unroll up to factor."), 76 llvm::cl::init(false)}; 77 Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."), 78 llvm::cl::init(0)}; 79 }; 80 } // namespace 81 82 namespace mlir { 83 namespace test { 84 void registerTestLoopUnrollingPass() { 85 PassRegistration<TestLoopUnrollingPass>(); 86 } 87 } // namespace test 88 } // namespace mlir 89