1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unrolling.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/LoopAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/Passes.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/LoopUtils.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/Debug.h"
24 
25 using namespace mlir;
26 
27 #define DEBUG_TYPE "affine-loop-unroll"
28 
29 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
30 
31 // Loop unrolling factor.
32 static llvm::cl::opt<unsigned> clUnrollFactor(
33     "unroll-factor",
34     llvm::cl::desc("Use this unroll factor for all loops being unrolled"),
35     llvm::cl::cat(clOptionsCategory));
36 
37 static llvm::cl::opt<bool> clUnrollFull("unroll-full",
38                                         llvm::cl::desc("Fully unroll loops"),
39                                         llvm::cl::cat(clOptionsCategory));
40 
41 static llvm::cl::opt<unsigned> clUnrollNumRepetitions(
42     "unroll-num-reps",
43     llvm::cl::desc("Unroll innermost loops repeatedly this many times"),
44     llvm::cl::cat(clOptionsCategory));
45 
46 static llvm::cl::opt<unsigned> clUnrollFullThreshold(
47     "unroll-full-threshold", llvm::cl::Hidden,
48     llvm::cl::desc(
49         "Unroll all loops with trip count less than or equal to this"),
50     llvm::cl::cat(clOptionsCategory));
51 
52 namespace {
53 
54 // TODO: this is really a test pass and should be moved out of dialect
55 // transforms.
56 
57 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
58 /// full unroll threshold was specified, in which case, fully unrolls all loops
59 /// with trip count less than the specified threshold. The latter is for testing
60 /// purposes, especially for testing outer loop unrolling.
61 struct LoopUnroll : public FunctionPass<LoopUnroll> {
62 /// Include the generated pass utilities.
63 #define GEN_PASS_AffineUnroll
64 #include "mlir/Dialect/Affine/Passes.h.inc"
65 
66   const Optional<unsigned> unrollFactor;
67   const Optional<bool> unrollFull;
68   // Callback to obtain unroll factors; if this has a callable target, takes
69   // precedence over command-line argument or passed argument.
70   const std::function<unsigned(AffineForOp)> getUnrollFactor;
71 
72   explicit LoopUnroll(
73       Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
74       const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
75       : unrollFactor(unrollFactor), unrollFull(unrollFull),
76         getUnrollFactor(getUnrollFactor) {}
77 
78   void runOnFunction() override;
79 
80   /// Unroll this for op. Returns failure if nothing was done.
81   LogicalResult runOnAffineForOp(AffineForOp forOp);
82 
83   static const unsigned kDefaultUnrollFactor = 4;
84 };
85 } // end anonymous namespace
86 
87 /// Returns true if no other affine.for ops are nested within.
88 static bool isInnermostAffineForOp(AffineForOp forOp) {
89   // Only for the innermost affine.for op's.
90   bool isInnermost = true;
91   forOp.walk([&](AffineForOp thisForOp) {
92     // Since this is a post order walk, we are able to conclude here.
93     isInnermost = (thisForOp == forOp);
94     return WalkResult::interrupt();
95   });
96   return isInnermost;
97 }
98 
99 /// Gathers loops that have no affine.for's nested within.
100 static void gatherInnermostLoops(FuncOp f,
101                                  SmallVectorImpl<AffineForOp> &loops) {
102   f.walk([&](AffineForOp forOp) {
103     if (isInnermostAffineForOp(forOp))
104       loops.push_back(forOp);
105   });
106 }
107 
108 void LoopUnroll::runOnFunction() {
109   if (clUnrollFull.getNumOccurrences() > 0 &&
110       clUnrollFullThreshold.getNumOccurrences() > 0) {
111     // Store short loops as we walk.
112     SmallVector<AffineForOp, 4> loops;
113 
114     // Gathers all loops with trip count <= minTripCount. Do a post order walk
115     // so that loops are gathered from innermost to outermost (or else unrolling
116     // an outer one may delete gathered inner ones).
117     getFunction().walk([&](AffineForOp forOp) {
118       Optional<uint64_t> tripCount = getConstantTripCount(forOp);
119       if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
120         loops.push_back(forOp);
121     });
122     for (auto forOp : loops)
123       loopUnrollFull(forOp);
124     return;
125   }
126 
127   unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0
128                                 ? clUnrollNumRepetitions
129                                 : 1;
130   // If the call back is provided, we will recurse until no loops are found.
131   FuncOp func = getFunction();
132   SmallVector<AffineForOp, 4> loops;
133   for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
134     loops.clear();
135     gatherInnermostLoops(func, loops);
136     if (loops.empty())
137       break;
138     bool unrolled = false;
139     for (auto forOp : loops)
140       unrolled |= succeeded(runOnAffineForOp(forOp));
141     if (!unrolled)
142       // Break out if nothing was unrolled.
143       break;
144   }
145 }
146 
147 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
148 /// failure otherwise. The default unroll factor is 4.
149 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
150   // Use the function callback if one was provided.
151   if (getUnrollFactor) {
152     return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
153   }
154   // Unroll by the factor passed, if any.
155   if (unrollFactor.hasValue())
156     return loopUnrollByFactor(forOp, unrollFactor.getValue());
157   // Unroll by the command line factor if one was specified.
158   if (clUnrollFactor.getNumOccurrences() > 0)
159     return loopUnrollByFactor(forOp, clUnrollFactor);
160   // Unroll completely if full loop unroll was specified.
161   if (clUnrollFull.getNumOccurrences() > 0 ||
162       (unrollFull.hasValue() && unrollFull.getValue()))
163     return loopUnrollFull(forOp);
164 
165   // Unroll by four otherwise.
166   return loopUnrollByFactor(forOp, kDefaultUnrollFactor);
167 }
168 
169 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopUnrollPass(
170     int unrollFactor, int unrollFull,
171     const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
172   return std::make_unique<LoopUnroll>(
173       unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
174       unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
175 }
176