1 //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to unroll loops by a specified unroll factor. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/SCF/SCF.h" 15 #include "mlir/Dialect/SCF/Utils/Utils.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/Pass/Pass.h" 18 19 using namespace mlir; 20 21 namespace { 22 23 static unsigned getNestingDepth(Operation *op) { 24 Operation *currOp = op; 25 unsigned depth = 0; 26 while ((currOp = currOp->getParentOp())) { 27 if (isa<scf::ForOp>(currOp)) 28 depth++; 29 } 30 return depth; 31 } 32 33 struct TestLoopUnrollingPass 34 : public PassWrapper<TestLoopUnrollingPass, OperationPass<>> { 35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopUnrollingPass) 36 37 StringRef getArgument() const final { return "test-loop-unrolling"; } 38 StringRef getDescription() const final { 39 return "Tests loop unrolling transformation"; 40 } 41 TestLoopUnrollingPass() = default; 42 TestLoopUnrollingPass(const TestLoopUnrollingPass &) {} 43 explicit TestLoopUnrollingPass(uint64_t unrollFactorParam, 44 unsigned loopDepthParam, 45 bool annotateLoopParam) { 46 unrollFactor = unrollFactorParam; 47 loopDepth = loopDepthParam; 48 annotateLoop = annotateLoopParam; 49 } 50 51 void getDependentDialects(DialectRegistry ®istry) const override { 52 registry.insert<arith::ArithmeticDialect>(); 53 } 54 55 void runOnOperation() override { 56 SmallVector<scf::ForOp, 4> loops; 57 getOperation()->walk([&](scf::ForOp forOp) { 58 if (getNestingDepth(forOp) == loopDepth) 59 loops.push_back(forOp); 60 }); 61 auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) { 62 if (annotateLoop) { 63 op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i)); 64 } 65 }; 66 for (auto loop : loops) 67 (void)loopUnrollByFactor(loop, unrollFactor, annotateFn); 68 } 69 Option<uint64_t> unrollFactor{*this, "unroll-factor", 70 llvm::cl::desc("Loop unroll factor."), 71 llvm::cl::init(1)}; 72 Option<bool> annotateLoop{*this, "annotate", 73 llvm::cl::desc("Annotate unrolled iterations."), 74 llvm::cl::init(false)}; 75 Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor", 76 llvm::cl::desc("Loop unroll up to factor."), 77 llvm::cl::init(false)}; 78 Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."), 79 llvm::cl::init(0)}; 80 }; 81 } // namespace 82 83 namespace mlir { 84 namespace test { 85 void registerTestLoopUnrollingPass() { 86 PassRegistration<TestLoopUnrollingPass>(); 87 } 88 } // namespace test 89 } // namespace mlir 90