1 //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to unroll loops by a specified unroll factor.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/Dialect/SCF/Utils/Utils.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Pass/Pass.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 
24 static unsigned getNestingDepth(Operation *op) {
25   Operation *currOp = op;
26   unsigned depth = 0;
27   while ((currOp = currOp->getParentOp())) {
28     if (isa<scf::ForOp>(currOp))
29       depth++;
30   }
31   return depth;
32 }
33 
34 class TestLoopUnrollingPass
35     : public PassWrapper<TestLoopUnrollingPass, OperationPass<FuncOp>> {
36 public:
37   StringRef getArgument() const final { return "test-loop-unrolling"; }
38   StringRef getDescription() const final {
39     return "Tests loop unrolling transformation";
40   }
41   TestLoopUnrollingPass() = default;
42   TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
43   explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
44                                  unsigned loopDepthParam,
45                                  bool annotateLoopParam) {
46     unrollFactor = unrollFactorParam;
47     loopDepth = loopDepthParam;
48     annotateLoop = annotateLoopParam;
49   }
50 
51   void getDependentDialects(DialectRegistry &registry) const override {
52     registry.insert<arith::ArithmeticDialect, StandardOpsDialect>();
53   }
54 
55   void runOnOperation() override {
56     FuncOp func = getOperation();
57     SmallVector<scf::ForOp, 4> loops;
58     func.walk([&](scf::ForOp forOp) {
59       if (getNestingDepth(forOp) == loopDepth)
60         loops.push_back(forOp);
61     });
62     auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
63       if (annotateLoop) {
64         op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
65       }
66     };
67     for (auto loop : loops)
68       (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
69   }
70   Option<uint64_t> unrollFactor{*this, "unroll-factor",
71                                 llvm::cl::desc("Loop unroll factor."),
72                                 llvm::cl::init(1)};
73   Option<bool> annotateLoop{*this, "annotate",
74                             llvm::cl::desc("Annotate unrolled iterations."),
75                             llvm::cl::init(false)};
76   Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor",
77                                 llvm::cl::desc("Loop unroll up to factor."),
78                                 llvm::cl::init(false)};
79   Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
80                              llvm::cl::init(0)};
81 };
82 } // namespace
83 
84 namespace mlir {
85 namespace test {
86 void registerTestLoopUnrollingPass() {
87   PassRegistration<TestLoopUnrollingPass>();
88 }
89 } // namespace test
90 } // namespace mlir
91