1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop unrolling. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/LoopAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Affine/Passes.h" 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/LoopUtils.h" 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/Support/CommandLine.h" 23 #include "llvm/Support/Debug.h" 24 25 using namespace mlir; 26 27 #define DEBUG_TYPE "affine-loop-unroll" 28 29 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 30 31 // Loop unrolling factor. 32 static llvm::cl::opt<unsigned> clUnrollFactor( 33 "unroll-factor", 34 llvm::cl::desc("Use this unroll factor for all loops being unrolled"), 35 llvm::cl::cat(clOptionsCategory)); 36 37 static llvm::cl::opt<bool> clUnrollFull("unroll-full", 38 llvm::cl::desc("Fully unroll loops"), 39 llvm::cl::cat(clOptionsCategory)); 40 41 static llvm::cl::opt<unsigned> clUnrollNumRepetitions( 42 "unroll-num-reps", 43 llvm::cl::desc("Unroll innermost loops repeatedly this many times"), 44 llvm::cl::cat(clOptionsCategory)); 45 46 static llvm::cl::opt<unsigned> clUnrollFullThreshold( 47 "unroll-full-threshold", llvm::cl::Hidden, 48 llvm::cl::desc( 49 "Unroll all loops with trip count less than or equal to this"), 50 llvm::cl::cat(clOptionsCategory)); 51 52 namespace { 53 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a 54 /// full unroll threshold was specified, in which case, fully unrolls all loops 55 /// with trip count less than the specified threshold. The latter is for testing 56 /// purposes, especially for testing outer loop unrolling. 57 struct LoopUnroll : public FunctionPass<LoopUnroll> { 58 const Optional<unsigned> unrollFactor; 59 const Optional<bool> unrollFull; 60 // Callback to obtain unroll factors; if this has a callable target, takes 61 // precedence over command-line argument or passed argument. 62 const std::function<unsigned(AffineForOp)> getUnrollFactor; 63 64 explicit LoopUnroll( 65 Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, 66 const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) 67 : unrollFactor(unrollFactor), unrollFull(unrollFull), 68 getUnrollFactor(getUnrollFactor) {} 69 70 void runOnFunction() override; 71 72 /// Unroll this for op. Returns failure if nothing was done. 73 LogicalResult runOnAffineForOp(AffineForOp forOp); 74 75 static const unsigned kDefaultUnrollFactor = 4; 76 }; 77 } // end anonymous namespace 78 79 void LoopUnroll::runOnFunction() { 80 // Gathers all innermost loops through a post order pruned walk. 81 struct InnermostLoopGatherer { 82 // Store innermost loops as we walk. 83 std::vector<AffineForOp> loops; 84 85 void walkPostOrder(FuncOp f) { 86 for (auto &b : f) 87 walkPostOrder(b.begin(), b.end()); 88 } 89 90 bool walkPostOrder(Block::iterator Start, Block::iterator End) { 91 bool hasInnerLoops = false; 92 // We need to walk all elements since all innermost loops need to be 93 // gathered as opposed to determining whether this list has any inner 94 // loops or not. 95 while (Start != End) 96 hasInnerLoops |= walkPostOrder(&(*Start++)); 97 return hasInnerLoops; 98 } 99 bool walkPostOrder(Operation *opInst) { 100 bool hasInnerLoops = false; 101 for (auto ®ion : opInst->getRegions()) 102 for (auto &block : region) 103 hasInnerLoops |= walkPostOrder(block.begin(), block.end()); 104 if (isa<AffineForOp>(opInst)) { 105 if (!hasInnerLoops) 106 loops.push_back(cast<AffineForOp>(opInst)); 107 return true; 108 } 109 return hasInnerLoops; 110 } 111 }; 112 113 if (clUnrollFull.getNumOccurrences() > 0 && 114 clUnrollFullThreshold.getNumOccurrences() > 0) { 115 // Store short loops as we walk. 116 std::vector<AffineForOp> loops; 117 118 // Gathers all loops with trip count <= minTripCount. Do a post order walk 119 // so that loops are gathered from innermost to outermost (or else unrolling 120 // an outer one may delete gathered inner ones). 121 getFunction().walk([&](AffineForOp forOp) { 122 Optional<uint64_t> tripCount = getConstantTripCount(forOp); 123 if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) 124 loops.push_back(forOp); 125 }); 126 for (auto forOp : loops) 127 loopUnrollFull(forOp); 128 return; 129 } 130 131 unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0 132 ? clUnrollNumRepetitions 133 : 1; 134 // If the call back is provided, we will recurse until no loops are found. 135 FuncOp func = getFunction(); 136 for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { 137 InnermostLoopGatherer ilg; 138 ilg.walkPostOrder(func); 139 auto &loops = ilg.loops; 140 if (loops.empty()) 141 break; 142 bool unrolled = false; 143 for (auto forOp : loops) 144 unrolled |= succeeded(runOnAffineForOp(forOp)); 145 if (!unrolled) 146 // Break out if nothing was unrolled. 147 break; 148 } 149 } 150 151 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, 152 /// failure otherwise. The default unroll factor is 4. 153 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { 154 // Use the function callback if one was provided. 155 if (getUnrollFactor) { 156 return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); 157 } 158 // Unroll by the factor passed, if any. 159 if (unrollFactor.hasValue()) 160 return loopUnrollByFactor(forOp, unrollFactor.getValue()); 161 // Unroll by the command line factor if one was specified. 162 if (clUnrollFactor.getNumOccurrences() > 0) 163 return loopUnrollByFactor(forOp, clUnrollFactor); 164 // Unroll completely if full loop unroll was specified. 165 if (clUnrollFull.getNumOccurrences() > 0 || 166 (unrollFull.hasValue() && unrollFull.getValue())) 167 return loopUnrollFull(forOp); 168 169 // Unroll by four otherwise. 170 return loopUnrollByFactor(forOp, kDefaultUnrollFactor); 171 } 172 173 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLoopUnrollPass( 174 int unrollFactor, int unrollFull, 175 const std::function<unsigned(AffineForOp)> &getUnrollFactor) { 176 return std::make_unique<LoopUnroll>( 177 unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), 178 unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); 179 } 180 181 static PassRegistration<LoopUnroll> pass("affine-loop-unroll", "Unroll loops"); 182