1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unrolling.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Analysis/LoopAnalysis.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/Passes.h"
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/Transforms/LoopUtils.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/Debug.h"
24 
25 using namespace mlir;
26 
27 #define DEBUG_TYPE "affine-loop-unroll"
28 
29 namespace {
30 
31 // TODO: this is really a test pass and should be moved out of dialect
32 // transforms.
33 
34 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
35 /// full unroll threshold was specified, in which case, fully unrolls all loops
36 /// with trip count less than the specified threshold. The latter is for testing
37 /// purposes, especially for testing outer loop unrolling.
38 struct LoopUnroll : public AffineLoopUnrollBase<LoopUnroll> {
39   // Callback to obtain unroll factors; if this has a callable target, takes
40   // precedence over command-line argument or passed argument.
41   const std::function<unsigned(AffineForOp)> getUnrollFactor;
42 
43   LoopUnroll() : getUnrollFactor(nullptr) {}
44   LoopUnroll(const LoopUnroll &other)
45       : AffineLoopUnrollBase<LoopUnroll>(other),
46         getUnrollFactor(other.getUnrollFactor) {}
47   explicit LoopUnroll(
48       Optional<unsigned> unrollFactor = None, bool unrollFull = false,
49       const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
50       : getUnrollFactor(getUnrollFactor) {
51     if (unrollFactor)
52       this->unrollFactor = *unrollFactor;
53     this->unrollFull = unrollFull;
54   }
55 
56   void runOnFunction() override;
57 
58   /// Unroll this for op. Returns failure if nothing was done.
59   LogicalResult runOnAffineForOp(AffineForOp forOp);
60 };
61 } // end anonymous namespace
62 
63 /// Returns true if no other affine.for ops are nested within.
64 static bool isInnermostAffineForOp(AffineForOp forOp) {
65   // Only for the innermost affine.for op's.
66   bool isInnermost = true;
67   forOp.walk([&](AffineForOp thisForOp) {
68     // Since this is a post order walk, we are able to conclude here.
69     isInnermost = (thisForOp == forOp);
70     return WalkResult::interrupt();
71   });
72   return isInnermost;
73 }
74 
75 /// Gathers loops that have no affine.for's nested within.
76 static void gatherInnermostLoops(FuncOp f,
77                                  SmallVectorImpl<AffineForOp> &loops) {
78   f.walk([&](AffineForOp forOp) {
79     if (isInnermostAffineForOp(forOp))
80       loops.push_back(forOp);
81   });
82 }
83 
84 void LoopUnroll::runOnFunction() {
85   if (unrollFull && unrollFullThreshold.hasValue()) {
86     // Store short loops as we walk.
87     SmallVector<AffineForOp, 4> loops;
88 
89     // Gathers all loops with trip count <= minTripCount. Do a post order walk
90     // so that loops are gathered from innermost to outermost (or else unrolling
91     // an outer one may delete gathered inner ones).
92     getFunction().walk([&](AffineForOp forOp) {
93       Optional<uint64_t> tripCount = getConstantTripCount(forOp);
94       if (tripCount.hasValue() && tripCount.getValue() <= unrollFullThreshold)
95         loops.push_back(forOp);
96     });
97     for (auto forOp : loops)
98       loopUnrollFull(forOp);
99     return;
100   }
101 
102   // If the call back is provided, we will recurse until no loops are found.
103   FuncOp func = getFunction();
104   SmallVector<AffineForOp, 4> loops;
105   for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
106     loops.clear();
107     gatherInnermostLoops(func, loops);
108     if (loops.empty())
109       break;
110     bool unrolled = false;
111     for (auto forOp : loops)
112       unrolled |= succeeded(runOnAffineForOp(forOp));
113     if (!unrolled)
114       // Break out if nothing was unrolled.
115       break;
116   }
117 }
118 
119 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
120 /// failure otherwise. The default unroll factor is 4.
121 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
122   // Use the function callback if one was provided.
123   if (getUnrollFactor)
124     return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
125   // Unroll completely if full loop unroll was specified.
126   if (unrollFull)
127     return loopUnrollFull(forOp);
128   // Otherwise, unroll by the given unroll factor.
129   return loopUnrollByFactor(forOp, unrollFactor);
130 }
131 
132 std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollPass(
133     int unrollFactor, bool unrollFull,
134     const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
135   return std::make_unique<LoopUnroll>(
136       unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), unrollFull,
137       getUnrollFactor);
138 }
139