1 //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to unroll loops by a specified unroll factor.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SCF/Utils/Utils.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/Pass/Pass.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 
23 static unsigned getNestingDepth(Operation *op) {
24   Operation *currOp = op;
25   unsigned depth = 0;
26   while ((currOp = currOp->getParentOp())) {
27     if (isa<scf::ForOp>(currOp))
28       depth++;
29   }
30   return depth;
31 }
32 
33 class TestLoopUnrollingPass
34     : public PassWrapper<TestLoopUnrollingPass, OperationPass<>> {
35 public:
36   StringRef getArgument() const final { return "test-loop-unrolling"; }
37   StringRef getDescription() const final {
38     return "Tests loop unrolling transformation";
39   }
40   TestLoopUnrollingPass() = default;
41   TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
42   explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
43                                  unsigned loopDepthParam,
44                                  bool annotateLoopParam) {
45     unrollFactor = unrollFactorParam;
46     loopDepth = loopDepthParam;
47     annotateLoop = annotateLoopParam;
48   }
49 
50   void getDependentDialects(DialectRegistry &registry) const override {
51     registry.insert<arith::ArithmeticDialect>();
52   }
53 
54   void runOnOperation() override {
55     SmallVector<scf::ForOp, 4> loops;
56     getOperation()->walk([&](scf::ForOp forOp) {
57       if (getNestingDepth(forOp) == loopDepth)
58         loops.push_back(forOp);
59     });
60     auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
61       if (annotateLoop) {
62         op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
63       }
64     };
65     for (auto loop : loops)
66       (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
67   }
68   Option<uint64_t> unrollFactor{*this, "unroll-factor",
69                                 llvm::cl::desc("Loop unroll factor."),
70                                 llvm::cl::init(1)};
71   Option<bool> annotateLoop{*this, "annotate",
72                             llvm::cl::desc("Annotate unrolled iterations."),
73                             llvm::cl::init(false)};
74   Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor",
75                                 llvm::cl::desc("Loop unroll up to factor."),
76                                 llvm::cl::init(false)};
77   Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
78                              llvm::cl::init(0)};
79 };
80 } // namespace
81 
82 namespace mlir {
83 namespace test {
84 void registerTestLoopUnrollingPass() {
85   PassRegistration<TestLoopUnrollingPass>();
86 }
87 } // namespace test
88 } // namespace mlir
89