1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop unrolling. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Analysis/LoopAnalysis.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Affine/Passes.h" 17 #include "mlir/IR/AffineExpr.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/Transforms/LoopUtils.h" 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/Support/CommandLine.h" 23 #include "llvm/Support/Debug.h" 24 25 using namespace mlir; 26 27 #define DEBUG_TYPE "affine-loop-unroll" 28 29 namespace { 30 31 // TODO: this is really a test pass and should be moved out of dialect 32 // transforms. 33 34 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a 35 /// full unroll threshold was specified, in which case, fully unrolls all loops 36 /// with trip count less than the specified threshold. The latter is for testing 37 /// purposes, especially for testing outer loop unrolling. 38 struct LoopUnroll : public AffineLoopUnrollBase<LoopUnroll> { 39 // Callback to obtain unroll factors; if this has a callable target, takes 40 // precedence over command-line argument or passed argument. 41 const std::function<unsigned(AffineForOp)> getUnrollFactor; 42 43 LoopUnroll() : getUnrollFactor(nullptr) {} 44 LoopUnroll(const LoopUnroll &other) 45 : AffineLoopUnrollBase<LoopUnroll>(other), 46 getUnrollFactor(other.getUnrollFactor) {} 47 explicit LoopUnroll( 48 Optional<unsigned> unrollFactor = None, bool unrollFull = false, 49 const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) 50 : getUnrollFactor(getUnrollFactor) { 51 if (unrollFactor) 52 this->unrollFactor = *unrollFactor; 53 this->unrollFull = unrollFull; 54 } 55 56 void runOnFunction() override; 57 58 /// Unroll this for op. Returns failure if nothing was done. 59 LogicalResult runOnAffineForOp(AffineForOp forOp); 60 }; 61 } // end anonymous namespace 62 63 /// Returns true if no other affine.for ops are nested within. 64 static bool isInnermostAffineForOp(AffineForOp forOp) { 65 // Only for the innermost affine.for op's. 66 bool isInnermost = true; 67 forOp.walk([&](AffineForOp thisForOp) { 68 // Since this is a post order walk, we are able to conclude here. 69 isInnermost = (thisForOp == forOp); 70 return WalkResult::interrupt(); 71 }); 72 return isInnermost; 73 } 74 75 /// Gathers loops that have no affine.for's nested within. 76 static void gatherInnermostLoops(FuncOp f, 77 SmallVectorImpl<AffineForOp> &loops) { 78 f.walk([&](AffineForOp forOp) { 79 if (isInnermostAffineForOp(forOp)) 80 loops.push_back(forOp); 81 }); 82 } 83 84 void LoopUnroll::runOnFunction() { 85 if (unrollFull && unrollFullThreshold.hasValue()) { 86 // Store short loops as we walk. 87 SmallVector<AffineForOp, 4> loops; 88 89 // Gathers all loops with trip count <= minTripCount. Do a post order walk 90 // so that loops are gathered from innermost to outermost (or else unrolling 91 // an outer one may delete gathered inner ones). 92 getFunction().walk([&](AffineForOp forOp) { 93 Optional<uint64_t> tripCount = getConstantTripCount(forOp); 94 if (tripCount.hasValue() && tripCount.getValue() <= unrollFullThreshold) 95 loops.push_back(forOp); 96 }); 97 for (auto forOp : loops) 98 loopUnrollFull(forOp); 99 return; 100 } 101 102 // If the call back is provided, we will recurse until no loops are found. 103 FuncOp func = getFunction(); 104 SmallVector<AffineForOp, 4> loops; 105 for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { 106 loops.clear(); 107 gatherInnermostLoops(func, loops); 108 if (loops.empty()) 109 break; 110 bool unrolled = false; 111 for (auto forOp : loops) 112 unrolled |= succeeded(runOnAffineForOp(forOp)); 113 if (!unrolled) 114 // Break out if nothing was unrolled. 115 break; 116 } 117 } 118 119 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, 120 /// failure otherwise. The default unroll factor is 4. 121 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { 122 // Use the function callback if one was provided. 123 if (getUnrollFactor) 124 return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); 125 // Unroll completely if full loop unroll was specified. 126 if (unrollFull) 127 return loopUnrollFull(forOp); 128 // Otherwise, unroll by the given unroll factor. 129 return loopUnrollByFactor(forOp, unrollFactor); 130 } 131 132 std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollPass( 133 int unrollFactor, bool unrollFull, 134 const std::function<unsigned(AffineForOp)> &getUnrollFactor) { 135 return std::make_unique<LoopUnroll>( 136 unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), unrollFull, 137 getUnrollFactor); 138 } 139