1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop unrolling. 10 // 11 //===----------------------------------------------------------------------===// 12 #include "PassDetail.h" 13 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Affine/LoopUtils.h" 16 #include "mlir/Dialect/Affine/Passes.h" 17 #include "mlir/IR/AffineExpr.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Builders.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/Support/CommandLine.h" 22 #include "llvm/Support/Debug.h" 23 24 using namespace mlir; 25 26 #define DEBUG_TYPE "affine-loop-unroll" 27 28 namespace { 29 30 // TODO: this is really a test pass and should be moved out of dialect 31 // transforms. 32 33 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a 34 /// full unroll threshold was specified, in which case, fully unrolls all loops 35 /// with trip count less than the specified threshold. The latter is for testing 36 /// purposes, especially for testing outer loop unrolling. 37 struct LoopUnroll : public AffineLoopUnrollBase<LoopUnroll> { 38 // Callback to obtain unroll factors; if this has a callable target, takes 39 // precedence over command-line argument or passed argument. 40 const std::function<unsigned(AffineForOp)> getUnrollFactor; 41 42 LoopUnroll() : getUnrollFactor(nullptr) {} 43 LoopUnroll(const LoopUnroll &other) 44 45 = default; 46 explicit LoopUnroll( 47 Optional<unsigned> unrollFactor = None, bool unrollUpToFactor = false, 48 bool unrollFull = false, 49 const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) 50 : getUnrollFactor(getUnrollFactor) { 51 if (unrollFactor) 52 this->unrollFactor = *unrollFactor; 53 this->unrollUpToFactor = unrollUpToFactor; 54 this->unrollFull = unrollFull; 55 } 56 57 void runOnOperation() override; 58 59 /// Unroll this for op. Returns failure if nothing was done. 60 LogicalResult runOnAffineForOp(AffineForOp forOp); 61 }; 62 } // namespace 63 64 /// Returns true if no other affine.for ops are nested within `op`. 65 static bool isInnermostAffineForOp(AffineForOp op) { 66 return !op.getBody() 67 ->walk([&](AffineForOp nestedForOp) { 68 return WalkResult::interrupt(); 69 }) 70 .wasInterrupted(); 71 } 72 73 /// Gathers loops that have no affine.for's nested within. 74 static void gatherInnermostLoops(func::FuncOp f, 75 SmallVectorImpl<AffineForOp> &loops) { 76 f.walk([&](AffineForOp forOp) { 77 if (isInnermostAffineForOp(forOp)) 78 loops.push_back(forOp); 79 }); 80 } 81 82 void LoopUnroll::runOnOperation() { 83 func::FuncOp func = getOperation(); 84 if (func.isExternal()) 85 return; 86 87 if (unrollFull && unrollFullThreshold.hasValue()) { 88 // Store short loops as we walk. 89 SmallVector<AffineForOp, 4> loops; 90 91 // Gathers all loops with trip count <= minTripCount. Do a post order walk 92 // so that loops are gathered from innermost to outermost (or else unrolling 93 // an outer one may delete gathered inner ones). 94 getOperation().walk([&](AffineForOp forOp) { 95 Optional<uint64_t> tripCount = getConstantTripCount(forOp); 96 if (tripCount && *tripCount <= unrollFullThreshold) 97 loops.push_back(forOp); 98 }); 99 for (auto forOp : loops) 100 (void)loopUnrollFull(forOp); 101 return; 102 } 103 104 // If the call back is provided, we will recurse until no loops are found. 105 SmallVector<AffineForOp, 4> loops; 106 for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { 107 loops.clear(); 108 gatherInnermostLoops(func, loops); 109 if (loops.empty()) 110 break; 111 bool unrolled = false; 112 for (auto forOp : loops) 113 unrolled |= succeeded(runOnAffineForOp(forOp)); 114 if (!unrolled) 115 // Break out if nothing was unrolled. 116 break; 117 } 118 } 119 120 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, 121 /// failure otherwise. The default unroll factor is 4. 122 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { 123 // Use the function callback if one was provided. 124 if (getUnrollFactor) 125 return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); 126 // Unroll completely if full loop unroll was specified. 127 if (unrollFull) 128 return loopUnrollFull(forOp); 129 // Otherwise, unroll by the given unroll factor. 130 if (unrollUpToFactor) 131 return loopUnrollUpToFactor(forOp, unrollFactor); 132 return loopUnrollByFactor(forOp, unrollFactor); 133 } 134 135 std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLoopUnrollPass( 136 int unrollFactor, bool unrollUpToFactor, bool unrollFull, 137 const std::function<unsigned(AffineForOp)> &getUnrollFactor) { 138 return std::make_unique<LoopUnroll>( 139 unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), 140 unrollUpToFactor, unrollFull, getUnrollFactor); 141 } 142