1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unrolling.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/LoopAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/Passes.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/LoopUtils.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/Debug.h"
24 
25 using namespace mlir;
26 
27 #define DEBUG_TYPE "affine-loop-unroll"
28 
29 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
30 
31 // Loop unrolling factor.
32 static llvm::cl::opt<unsigned> clUnrollFactor(
33     "unroll-factor",
34     llvm::cl::desc("Use this unroll factor for all loops being unrolled"),
35     llvm::cl::cat(clOptionsCategory));
36 
37 static llvm::cl::opt<bool> clUnrollFull("unroll-full",
38                                         llvm::cl::desc("Fully unroll loops"),
39                                         llvm::cl::cat(clOptionsCategory));
40 
41 static llvm::cl::opt<unsigned> clUnrollNumRepetitions(
42     "unroll-num-reps",
43     llvm::cl::desc("Unroll innermost loops repeatedly this many times"),
44     llvm::cl::cat(clOptionsCategory));
45 
46 static llvm::cl::opt<unsigned> clUnrollFullThreshold(
47     "unroll-full-threshold", llvm::cl::Hidden,
48     llvm::cl::desc(
49         "Unroll all loops with trip count less than or equal to this"),
50     llvm::cl::cat(clOptionsCategory));
51 
52 namespace {
53 
54 // TODO: this is really a test pass and should be moved out of dialect
55 // transforms.
56 
57 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
58 /// full unroll threshold was specified, in which case, fully unrolls all loops
59 /// with trip count less than the specified threshold. The latter is for testing
60 /// purposes, especially for testing outer loop unrolling.
61 struct LoopUnroll : public FunctionPass<LoopUnroll> {
62   const Optional<unsigned> unrollFactor;
63   const Optional<bool> unrollFull;
64   // Callback to obtain unroll factors; if this has a callable target, takes
65   // precedence over command-line argument or passed argument.
66   const std::function<unsigned(AffineForOp)> getUnrollFactor;
67 
68   explicit LoopUnroll(
69       Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
70       const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
71       : unrollFactor(unrollFactor), unrollFull(unrollFull),
72         getUnrollFactor(getUnrollFactor) {}
73 
74   void runOnFunction() override;
75 
76   /// Unroll this for op. Returns failure if nothing was done.
77   LogicalResult runOnAffineForOp(AffineForOp forOp);
78 
79   static const unsigned kDefaultUnrollFactor = 4;
80 };
81 } // end anonymous namespace
82 
83 /// Returns true if no other affine.for ops are nested within.
84 static bool isInnermostAffineForOp(AffineForOp forOp) {
85   // Only for the innermost affine.for op's.
86   bool isInnermost = true;
87   forOp.walk([&](AffineForOp thisForOp) {
88     // Since this is a post order walk, we are able to conclude here.
89     isInnermost = (thisForOp == forOp);
90     return WalkResult::interrupt();
91   });
92   return isInnermost;
93 }
94 
95 /// Gathers loops that have no affine.for's nested within.
96 static void gatherInnermostLoops(FuncOp f,
97                                  SmallVectorImpl<AffineForOp> &loops) {
98   f.walk([&](AffineForOp forOp) {
99     if (isInnermostAffineForOp(forOp))
100       loops.push_back(forOp);
101   });
102 }
103 
104 void LoopUnroll::runOnFunction() {
105   if (clUnrollFull.getNumOccurrences() > 0 &&
106       clUnrollFullThreshold.getNumOccurrences() > 0) {
107     // Store short loops as we walk.
108     SmallVector<AffineForOp, 4> loops;
109 
110     // Gathers all loops with trip count <= minTripCount. Do a post order walk
111     // so that loops are gathered from innermost to outermost (or else unrolling
112     // an outer one may delete gathered inner ones).
113     getFunction().walk([&](AffineForOp forOp) {
114       Optional<uint64_t> tripCount = getConstantTripCount(forOp);
115       if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
116         loops.push_back(forOp);
117     });
118     for (auto forOp : loops)
119       loopUnrollFull(forOp);
120     return;
121   }
122 
123   unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0
124                                 ? clUnrollNumRepetitions
125                                 : 1;
126   // If the call back is provided, we will recurse until no loops are found.
127   FuncOp func = getFunction();
128   SmallVector<AffineForOp, 4> loops;
129   for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
130     loops.clear();
131     gatherInnermostLoops(func, loops);
132     if (loops.empty())
133       break;
134     bool unrolled = false;
135     for (auto forOp : loops)
136       unrolled |= succeeded(runOnAffineForOp(forOp));
137     if (!unrolled)
138       // Break out if nothing was unrolled.
139       break;
140   }
141 }
142 
143 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
144 /// failure otherwise. The default unroll factor is 4.
145 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
146   // Use the function callback if one was provided.
147   if (getUnrollFactor) {
148     return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
149   }
150   // Unroll by the factor passed, if any.
151   if (unrollFactor.hasValue())
152     return loopUnrollByFactor(forOp, unrollFactor.getValue());
153   // Unroll by the command line factor if one was specified.
154   if (clUnrollFactor.getNumOccurrences() > 0)
155     return loopUnrollByFactor(forOp, clUnrollFactor);
156   // Unroll completely if full loop unroll was specified.
157   if (clUnrollFull.getNumOccurrences() > 0 ||
158       (unrollFull.hasValue() && unrollFull.getValue()))
159     return loopUnrollFull(forOp);
160 
161   // Unroll by four otherwise.
162   return loopUnrollByFactor(forOp, kDefaultUnrollFactor);
163 }
164 
165 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopUnrollPass(
166     int unrollFactor, int unrollFull,
167     const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
168   return std::make_unique<LoopUnroll>(
169       unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
170       unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
171 }
172 
173 static PassRegistration<LoopUnroll> pass("affine-loop-unroll", "Unroll loops");
174