//===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop unrolling.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

using namespace mlir;

#define DEBUG_TYPE "affine-loop-unroll"

// Category under which all command-line options of this pass are grouped.
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

// Loop unrolling factor. The pass only consults this option when it was
// actually given on the command line.
static llvm::cl::opt<unsigned> clUnrollFactor(
    "unroll-factor",
    llvm::cl::desc("Use this unroll factor for all loops being unrolled"),
    llvm::cl::cat(clOptionsCategory));

// Requests full unrolling of the loops processed by the pass.
static llvm::cl::opt<bool> clUnrollFull("unroll-full",
                                        llvm::cl::desc("Fully unroll loops"),
                                        llvm::cl::cat(clOptionsCategory));

// Number of gather-and-unroll rounds over the innermost loops; the pass uses a
// single round when this option is not given.
static llvm::cl::opt<unsigned> clUnrollNumRepetitions(
    "unroll-num-reps",
    llvm::cl::desc("Unroll innermost loops repeatedly this many times"),
    llvm::cl::cat(clOptionsCategory));

// Hidden option: trip count threshold used together with "unroll-full" to
// fully unroll every loop whose constant trip count does not exceed it.
static llvm::cl::opt<unsigned> clUnrollFullThreshold(
    "unroll-full-threshold", llvm::cl::Hidden,
    llvm::cl::desc(
        "Unroll all loops with trip count less than or equal to this"),
    llvm::cl::cat(clOptionsCategory));

namespace {

// TODO: this is really a test pass and should be moved out of dialect
// transforms.
56 57 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a 58 /// full unroll threshold was specified, in which case, fully unrolls all loops 59 /// with trip count less than the specified threshold. The latter is for testing 60 /// purposes, especially for testing outer loop unrolling. 61 struct LoopUnroll : public FunctionPass<LoopUnroll> { 62 const Optional<unsigned> unrollFactor; 63 const Optional<bool> unrollFull; 64 // Callback to obtain unroll factors; if this has a callable target, takes 65 // precedence over command-line argument or passed argument. 66 const std::function<unsigned(AffineForOp)> getUnrollFactor; 67 68 explicit LoopUnroll( 69 Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, 70 const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) 71 : unrollFactor(unrollFactor), unrollFull(unrollFull), 72 getUnrollFactor(getUnrollFactor) {} 73 74 void runOnFunction() override; 75 76 /// Unroll this for op. Returns failure if nothing was done. 77 LogicalResult runOnAffineForOp(AffineForOp forOp); 78 79 static const unsigned kDefaultUnrollFactor = 4; 80 }; 81 } // end anonymous namespace 82 83 /// Returns true if no other affine.for ops are nested within. 84 static bool isInnermostAffineForOp(AffineForOp forOp) { 85 // Only for the innermost affine.for op's. 86 bool isInnermost = true; 87 forOp.walk([&](AffineForOp thisForOp) { 88 // Since this is a post order walk, we are able to conclude here. 89 isInnermost = (thisForOp == forOp); 90 return WalkResult::interrupt(); 91 }); 92 return isInnermost; 93 } 94 95 /// Gathers loops that have no affine.for's nested within. 
96 static void gatherInnermostLoops(FuncOp f, 97 SmallVectorImpl<AffineForOp> &loops) { 98 f.walk([&](AffineForOp forOp) { 99 if (isInnermostAffineForOp(forOp)) 100 loops.push_back(forOp); 101 }); 102 } 103 104 void LoopUnroll::runOnFunction() { 105 if (clUnrollFull.getNumOccurrences() > 0 && 106 clUnrollFullThreshold.getNumOccurrences() > 0) { 107 // Store short loops as we walk. 108 SmallVector<AffineForOp, 4> loops; 109 110 // Gathers all loops with trip count <= minTripCount. Do a post order walk 111 // so that loops are gathered from innermost to outermost (or else unrolling 112 // an outer one may delete gathered inner ones). 113 getFunction().walk([&](AffineForOp forOp) { 114 Optional<uint64_t> tripCount = getConstantTripCount(forOp); 115 if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) 116 loops.push_back(forOp); 117 }); 118 for (auto forOp : loops) 119 loopUnrollFull(forOp); 120 return; 121 } 122 123 unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0 124 ? clUnrollNumRepetitions 125 : 1; 126 // If the call back is provided, we will recurse until no loops are found. 127 FuncOp func = getFunction(); 128 SmallVector<AffineForOp, 4> loops; 129 for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { 130 loops.clear(); 131 gatherInnermostLoops(func, loops); 132 if (loops.empty()) 133 break; 134 bool unrolled = false; 135 for (auto forOp : loops) 136 unrolled |= succeeded(runOnAffineForOp(forOp)); 137 if (!unrolled) 138 // Break out if nothing was unrolled. 139 break; 140 } 141 } 142 143 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, 144 /// failure otherwise. The default unroll factor is 4. 145 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { 146 // Use the function callback if one was provided. 147 if (getUnrollFactor) { 148 return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); 149 } 150 // Unroll by the factor passed, if any. 
151 if (unrollFactor.hasValue()) 152 return loopUnrollByFactor(forOp, unrollFactor.getValue()); 153 // Unroll by the command line factor if one was specified. 154 if (clUnrollFactor.getNumOccurrences() > 0) 155 return loopUnrollByFactor(forOp, clUnrollFactor); 156 // Unroll completely if full loop unroll was specified. 157 if (clUnrollFull.getNumOccurrences() > 0 || 158 (unrollFull.hasValue() && unrollFull.getValue())) 159 return loopUnrollFull(forOp); 160 161 // Unroll by four otherwise. 162 return loopUnrollByFactor(forOp, kDefaultUnrollFactor); 163 } 164 165 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopUnrollPass( 166 int unrollFactor, int unrollFull, 167 const std::function<unsigned(AffineForOp)> &getUnrollFactor) { 168 return std::make_unique<LoopUnroll>( 169 unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), 170 unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); 171 } 172 173 static PassRegistration<LoopUnroll> pass("affine-loop-unroll", "Unroll loops"); 174