1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unrolling.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/Passes.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Transforms/LoopUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/Support/CommandLine.h"
22 #include "llvm/Support/Debug.h"
23 
24 using namespace mlir;
25 
26 #define DEBUG_TYPE "affine-loop-unroll"
27 
28 namespace {
29 
30 // TODO: this is really a test pass and should be moved out of dialect
31 // transforms.
32 
33 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
34 /// full unroll threshold was specified, in which case, fully unrolls all loops
35 /// with trip count less than the specified threshold. The latter is for testing
36 /// purposes, especially for testing outer loop unrolling.
37 struct LoopUnroll : public AffineLoopUnrollBase<LoopUnroll> {
38   // Callback to obtain unroll factors; if this has a callable target, takes
39   // precedence over command-line argument or passed argument.
40   const std::function<unsigned(AffineForOp)> getUnrollFactor;
41 
42   LoopUnroll() : getUnrollFactor(nullptr) {}
43   LoopUnroll(const LoopUnroll &other)
44 
45       = default;
46   explicit LoopUnroll(
47       Optional<unsigned> unrollFactor = None, bool unrollUpToFactor = false,
48       bool unrollFull = false,
49       const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
50       : getUnrollFactor(getUnrollFactor) {
51     if (unrollFactor)
52       this->unrollFactor = *unrollFactor;
53     this->unrollUpToFactor = unrollUpToFactor;
54     this->unrollFull = unrollFull;
55   }
56 
57   void runOnOperation() override;
58 
59   /// Unroll this for op. Returns failure if nothing was done.
60   LogicalResult runOnAffineForOp(AffineForOp forOp);
61 };
62 } // namespace
63 
64 /// Returns true if no other affine.for ops are nested within.
65 static bool isInnermostAffineForOp(AffineForOp forOp) {
66   // Only for the innermost affine.for op's.
67   bool isInnermost = true;
68   forOp.walk([&](AffineForOp thisForOp) {
69     // Since this is a post order walk, we are able to conclude here.
70     isInnermost = (thisForOp == forOp);
71     return WalkResult::interrupt();
72   });
73   return isInnermost;
74 }
75 
76 /// Gathers loops that have no affine.for's nested within.
77 static void gatherInnermostLoops(FuncOp f,
78                                  SmallVectorImpl<AffineForOp> &loops) {
79   f.walk([&](AffineForOp forOp) {
80     if (isInnermostAffineForOp(forOp))
81       loops.push_back(forOp);
82   });
83 }
84 
85 void LoopUnroll::runOnOperation() {
86   FuncOp func = getOperation();
87   if (func.isExternal())
88     return;
89 
90   if (unrollFull && unrollFullThreshold.hasValue()) {
91     // Store short loops as we walk.
92     SmallVector<AffineForOp, 4> loops;
93 
94     // Gathers all loops with trip count <= minTripCount. Do a post order walk
95     // so that loops are gathered from innermost to outermost (or else unrolling
96     // an outer one may delete gathered inner ones).
97     getOperation().walk([&](AffineForOp forOp) {
98       Optional<uint64_t> tripCount = getConstantTripCount(forOp);
99       if (tripCount.hasValue() && tripCount.getValue() <= unrollFullThreshold)
100         loops.push_back(forOp);
101     });
102     for (auto forOp : loops)
103       (void)loopUnrollFull(forOp);
104     return;
105   }
106 
107   // If the call back is provided, we will recurse until no loops are found.
108   SmallVector<AffineForOp, 4> loops;
109   for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
110     loops.clear();
111     gatherInnermostLoops(func, loops);
112     if (loops.empty())
113       break;
114     bool unrolled = false;
115     for (auto forOp : loops)
116       unrolled |= succeeded(runOnAffineForOp(forOp));
117     if (!unrolled)
118       // Break out if nothing was unrolled.
119       break;
120   }
121 }
122 
123 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
124 /// failure otherwise. The default unroll factor is 4.
125 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
126   // Use the function callback if one was provided.
127   if (getUnrollFactor)
128     return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
129   // Unroll completely if full loop unroll was specified.
130   if (unrollFull)
131     return loopUnrollFull(forOp);
132   // Otherwise, unroll by the given unroll factor.
133   if (unrollUpToFactor)
134     return loopUnrollUpToFactor(forOp, unrollFactor);
135   return loopUnrollByFactor(forOp, unrollFactor);
136 }
137 
138 std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollPass(
139     int unrollFactor, bool unrollUpToFactor, bool unrollFull,
140     const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
141   return std::make_unique<LoopUnroll>(
142       unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
143       unrollUpToFactor, unrollFull, getUnrollFactor);
144 }
145