1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop unrolling. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/LoopAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Affine/Passes.h" 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/LoopUtils.h" 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/Support/CommandLine.h" 23 #include "llvm/Support/Debug.h" 24 25 using namespace mlir; 26 27 #define DEBUG_TYPE "affine-loop-unroll" 28 29 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 30 31 // Loop unrolling factor. 32 static llvm::cl::opt<unsigned> clUnrollFactor( 33 "unroll-factor", 34 llvm::cl::desc("Use this unroll factor for all loops being unrolled"), 35 llvm::cl::cat(clOptionsCategory)); 36 37 static llvm::cl::opt<bool> clUnrollFull("unroll-full", 38 llvm::cl::desc("Fully unroll loops"), 39 llvm::cl::cat(clOptionsCategory)); 40 41 static llvm::cl::opt<unsigned> clUnrollNumRepetitions( 42 "unroll-num-reps", 43 llvm::cl::desc("Unroll innermost loops repeatedly this many times"), 44 llvm::cl::cat(clOptionsCategory)); 45 46 static llvm::cl::opt<unsigned> clUnrollFullThreshold( 47 "unroll-full-threshold", llvm::cl::Hidden, 48 llvm::cl::desc( 49 "Unroll all loops with trip count less than or equal to this"), 50 llvm::cl::cat(clOptionsCategory)); 51 52 namespace { 53 54 // TODO: this is really a test pass and should be moved out of dialect 55 // transforms. 56 57 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a 58 /// full unroll threshold was specified, in which case, fully unrolls all loops 59 /// with trip count less than the specified threshold. The latter is for testing 60 /// purposes, especially for testing outer loop unrolling. 61 struct LoopUnroll : public FunctionPass<LoopUnroll> { 62 /// Include the generated pass utilities. 63 #define GEN_PASS_AffineUnroll 64 #include "mlir/Dialect/Affine/Passes.h.inc" 65 66 const Optional<unsigned> unrollFactor; 67 const Optional<bool> unrollFull; 68 // Callback to obtain unroll factors; if this has a callable target, takes 69 // precedence over command-line argument or passed argument. 70 const std::function<unsigned(AffineForOp)> getUnrollFactor; 71 72 explicit LoopUnroll( 73 Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, 74 const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) 75 : unrollFactor(unrollFactor), unrollFull(unrollFull), 76 getUnrollFactor(getUnrollFactor) {} 77 78 void runOnFunction() override; 79 80 /// Unroll this for op. Returns failure if nothing was done. 81 LogicalResult runOnAffineForOp(AffineForOp forOp); 82 83 static const unsigned kDefaultUnrollFactor = 4; 84 }; 85 } // end anonymous namespace 86 87 /// Returns true if no other affine.for ops are nested within. 88 static bool isInnermostAffineForOp(AffineForOp forOp) { 89 // Only for the innermost affine.for op's. 90 bool isInnermost = true; 91 forOp.walk([&](AffineForOp thisForOp) { 92 // Since this is a post order walk, we are able to conclude here. 93 isInnermost = (thisForOp == forOp); 94 return WalkResult::interrupt(); 95 }); 96 return isInnermost; 97 } 98 99 /// Gathers loops that have no affine.for's nested within. 100 static void gatherInnermostLoops(FuncOp f, 101 SmallVectorImpl<AffineForOp> &loops) { 102 f.walk([&](AffineForOp forOp) { 103 if (isInnermostAffineForOp(forOp)) 104 loops.push_back(forOp); 105 }); 106 } 107 108 void LoopUnroll::runOnFunction() { 109 if (clUnrollFull.getNumOccurrences() > 0 && 110 clUnrollFullThreshold.getNumOccurrences() > 0) { 111 // Store short loops as we walk. 112 SmallVector<AffineForOp, 4> loops; 113 114 // Gathers all loops with trip count <= minTripCount. Do a post order walk 115 // so that loops are gathered from innermost to outermost (or else unrolling 116 // an outer one may delete gathered inner ones). 117 getFunction().walk([&](AffineForOp forOp) { 118 Optional<uint64_t> tripCount = getConstantTripCount(forOp); 119 if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) 120 loops.push_back(forOp); 121 }); 122 for (auto forOp : loops) 123 loopUnrollFull(forOp); 124 return; 125 } 126 127 unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0 128 ? clUnrollNumRepetitions 129 : 1; 130 // If the call back is provided, we will recurse until no loops are found. 131 FuncOp func = getFunction(); 132 SmallVector<AffineForOp, 4> loops; 133 for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { 134 loops.clear(); 135 gatherInnermostLoops(func, loops); 136 if (loops.empty()) 137 break; 138 bool unrolled = false; 139 for (auto forOp : loops) 140 unrolled |= succeeded(runOnAffineForOp(forOp)); 141 if (!unrolled) 142 // Break out if nothing was unrolled. 143 break; 144 } 145 } 146 147 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, 148 /// failure otherwise. The default unroll factor is 4. 149 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { 150 // Use the function callback if one was provided. 151 if (getUnrollFactor) { 152 return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); 153 } 154 // Unroll by the factor passed, if any. 155 if (unrollFactor.hasValue()) 156 return loopUnrollByFactor(forOp, unrollFactor.getValue()); 157 // Unroll by the command line factor if one was specified. 158 if (clUnrollFactor.getNumOccurrences() > 0) 159 return loopUnrollByFactor(forOp, clUnrollFactor); 160 // Unroll completely if full loop unroll was specified. 161 if (clUnrollFull.getNumOccurrences() > 0 || 162 (unrollFull.hasValue() && unrollFull.getValue())) 163 return loopUnrollFull(forOp); 164 165 // Unroll by four otherwise. 166 return loopUnrollByFactor(forOp, kDefaultUnrollFactor); 167 } 168 169 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopUnrollPass( 170 int unrollFactor, int unrollFull, 171 const std::function<unsigned(AffineForOp)> &getUnrollFactor) { 172 return std::make_unique<LoopUnroll>( 173 unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), 174 unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); 175 } 176