1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop unrolling. 10 // 11 //===----------------------------------------------------------------------===// 12 #include "PassDetail.h" 13 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Affine/Passes.h" 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/Transforms/LoopUtils.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/Support/CommandLine.h" 22 #include "llvm/Support/Debug.h" 23 24 using namespace mlir; 25 26 #define DEBUG_TYPE "affine-loop-unroll" 27 28 namespace { 29 30 // TODO: this is really a test pass and should be moved out of dialect 31 // transforms. 32 33 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a 34 /// full unroll threshold was specified, in which case, fully unrolls all loops 35 /// with trip count less than the specified threshold. The latter is for testing 36 /// purposes, especially for testing outer loop unrolling. 37 struct LoopUnroll : public AffineLoopUnrollBase<LoopUnroll> { 38 // Callback to obtain unroll factors; if this has a callable target, takes 39 // precedence over command-line argument or passed argument. 40 const std::function<unsigned(AffineForOp)> getUnrollFactor; 41 42 LoopUnroll() : getUnrollFactor(nullptr) {} 43 LoopUnroll(const LoopUnroll &other) 44 45 = default; 46 explicit LoopUnroll( 47 Optional<unsigned> unrollFactor = None, bool unrollUpToFactor = false, 48 bool unrollFull = false, 49 const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) 50 : getUnrollFactor(getUnrollFactor) { 51 if (unrollFactor) 52 this->unrollFactor = *unrollFactor; 53 this->unrollUpToFactor = unrollUpToFactor; 54 this->unrollFull = unrollFull; 55 } 56 57 void runOnOperation() override; 58 59 /// Unroll this for op. Returns failure if nothing was done. 60 LogicalResult runOnAffineForOp(AffineForOp forOp); 61 }; 62 } // namespace 63 64 /// Returns true if no other affine.for ops are nested within. 65 static bool isInnermostAffineForOp(AffineForOp forOp) { 66 // Only for the innermost affine.for op's. 67 bool isInnermost = true; 68 forOp.walk([&](AffineForOp thisForOp) { 69 // Since this is a post order walk, we are able to conclude here. 70 isInnermost = (thisForOp == forOp); 71 return WalkResult::interrupt(); 72 }); 73 return isInnermost; 74 } 75 76 /// Gathers loops that have no affine.for's nested within. 77 static void gatherInnermostLoops(FuncOp f, 78 SmallVectorImpl<AffineForOp> &loops) { 79 f.walk([&](AffineForOp forOp) { 80 if (isInnermostAffineForOp(forOp)) 81 loops.push_back(forOp); 82 }); 83 } 84 85 void LoopUnroll::runOnOperation() { 86 FuncOp func = getOperation(); 87 if (func.isExternal()) 88 return; 89 90 if (unrollFull && unrollFullThreshold.hasValue()) { 91 // Store short loops as we walk. 92 SmallVector<AffineForOp, 4> loops; 93 94 // Gathers all loops with trip count <= minTripCount. Do a post order walk 95 // so that loops are gathered from innermost to outermost (or else unrolling 96 // an outer one may delete gathered inner ones). 97 getOperation().walk([&](AffineForOp forOp) { 98 Optional<uint64_t> tripCount = getConstantTripCount(forOp); 99 if (tripCount.hasValue() && tripCount.getValue() <= unrollFullThreshold) 100 loops.push_back(forOp); 101 }); 102 for (auto forOp : loops) 103 (void)loopUnrollFull(forOp); 104 return; 105 } 106 107 // If the call back is provided, we will recurse until no loops are found. 108 SmallVector<AffineForOp, 4> loops; 109 for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { 110 loops.clear(); 111 gatherInnermostLoops(func, loops); 112 if (loops.empty()) 113 break; 114 bool unrolled = false; 115 for (auto forOp : loops) 116 unrolled |= succeeded(runOnAffineForOp(forOp)); 117 if (!unrolled) 118 // Break out if nothing was unrolled. 119 break; 120 } 121 } 122 123 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, 124 /// failure otherwise. The default unroll factor is 4. 125 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { 126 // Use the function callback if one was provided. 127 if (getUnrollFactor) 128 return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); 129 // Unroll completely if full loop unroll was specified. 130 if (unrollFull) 131 return loopUnrollFull(forOp); 132 // Otherwise, unroll by the given unroll factor. 133 if (unrollUpToFactor) 134 return loopUnrollUpToFactor(forOp, unrollFactor); 135 return loopUnrollByFactor(forOp, unrollFactor); 136 } 137 138 std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopUnrollPass( 139 int unrollFactor, bool unrollUpToFactor, bool unrollFull, 140 const std::function<unsigned(AffineForOp)> &getUnrollFactor) { 141 return std::make_unique<LoopUnroll>( 142 unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), 143 unrollUpToFactor, unrollFull, getUnrollFactor); 144 } 145