1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unrolling.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopUtils.h"
16 #include "mlir/Dialect/Affine/Passes.h"
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Builders.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/Support/CommandLine.h"
22 #include "llvm/Support/Debug.h"
23 
24 using namespace mlir;
25 
26 #define DEBUG_TYPE "affine-loop-unroll"
27 
28 namespace {
29 
30 // TODO: this is really a test pass and should be moved out of dialect
31 // transforms.
32 
33 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
34 /// full unroll threshold was specified, in which case, fully unrolls all loops
35 /// with trip count less than the specified threshold. The latter is for testing
36 /// purposes, especially for testing outer loop unrolling.
37 struct LoopUnroll : public AffineLoopUnrollBase<LoopUnroll> {
38   // Callback to obtain unroll factors; if this has a callable target, takes
39   // precedence over command-line argument or passed argument.
40   const std::function<unsigned(AffineForOp)> getUnrollFactor;
41 
42   LoopUnroll() : getUnrollFactor(nullptr) {}
43   LoopUnroll(const LoopUnroll &other)
44 
45       = default;
46   explicit LoopUnroll(
47       Optional<unsigned> unrollFactor = None, bool unrollUpToFactor = false,
48       bool unrollFull = false,
49       const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
50       : getUnrollFactor(getUnrollFactor) {
51     if (unrollFactor)
52       this->unrollFactor = *unrollFactor;
53     this->unrollUpToFactor = unrollUpToFactor;
54     this->unrollFull = unrollFull;
55   }
56 
57   void runOnOperation() override;
58 
59   /// Unroll this for op. Returns failure if nothing was done.
60   LogicalResult runOnAffineForOp(AffineForOp forOp);
61 };
62 } // namespace
63 
64 /// Returns true if no other affine.for ops are nested within `op`.
65 static bool isInnermostAffineForOp(AffineForOp op) {
66   return !op.getBody()
67               ->walk([&](AffineForOp nestedForOp) {
68                 return WalkResult::interrupt();
69               })
70               .wasInterrupted();
71 }
72 
73 /// Gathers loops that have no affine.for's nested within.
74 static void gatherInnermostLoops(func::FuncOp f,
75                                  SmallVectorImpl<AffineForOp> &loops) {
76   f.walk([&](AffineForOp forOp) {
77     if (isInnermostAffineForOp(forOp))
78       loops.push_back(forOp);
79   });
80 }
81 
82 void LoopUnroll::runOnOperation() {
83   func::FuncOp func = getOperation();
84   if (func.isExternal())
85     return;
86 
87   if (unrollFull && unrollFullThreshold.hasValue()) {
88     // Store short loops as we walk.
89     SmallVector<AffineForOp, 4> loops;
90 
91     // Gathers all loops with trip count <= minTripCount. Do a post order walk
92     // so that loops are gathered from innermost to outermost (or else unrolling
93     // an outer one may delete gathered inner ones).
94     getOperation().walk([&](AffineForOp forOp) {
95       Optional<uint64_t> tripCount = getConstantTripCount(forOp);
96       if (tripCount && *tripCount <= unrollFullThreshold)
97         loops.push_back(forOp);
98     });
99     for (auto forOp : loops)
100       (void)loopUnrollFull(forOp);
101     return;
102   }
103 
104   // If the call back is provided, we will recurse until no loops are found.
105   SmallVector<AffineForOp, 4> loops;
106   for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
107     loops.clear();
108     gatherInnermostLoops(func, loops);
109     if (loops.empty())
110       break;
111     bool unrolled = false;
112     for (auto forOp : loops)
113       unrolled |= succeeded(runOnAffineForOp(forOp));
114     if (!unrolled)
115       // Break out if nothing was unrolled.
116       break;
117   }
118 }
119 
120 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
121 /// failure otherwise. The default unroll factor is 4.
122 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
123   // Use the function callback if one was provided.
124   if (getUnrollFactor)
125     return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
126   // Unroll completely if full loop unroll was specified.
127   if (unrollFull)
128     return loopUnrollFull(forOp);
129   // Otherwise, unroll by the given unroll factor.
130   if (unrollUpToFactor)
131     return loopUnrollUpToFactor(forOp, unrollFactor);
132   return loopUnrollByFactor(forOp, unrollFactor);
133 }
134 
135 std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLoopUnrollPass(
136     int unrollFactor, bool unrollUpToFactor, bool unrollFull,
137     const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
138   return std::make_unique<LoopUnroll>(
139       unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
140       unrollUpToFactor, unrollFull, getUnrollFactor);
141 }
142