1 //===- ParallelLoopTiling.cpp - Tiles scf.parallel ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop tiling on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/SCF/Passes.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/SCF/Transforms.h" 19 #include "mlir/Dialect/SCF/Utils.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 22 using namespace mlir; 23 using namespace mlir::scf; 24 25 /// Tile a parallel loop of the form 26 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 27 /// step (%arg4, %arg5) 28 /// 29 /// into 30 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 31 /// step (%arg4*tileSize[0], 32 /// %arg5*tileSize[1]) 33 /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0) 34 /// min(%arg5*tileSize[1], %arg3-%i1)) 35 /// step (%arg4, %arg5) 36 /// 37 /// or, when no-min-max-bounds is true, into 38 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 39 /// step (%arg4*tileSize[0], 40 /// %arg5*tileSize[1]) 41 /// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0], 42 /// %arg5*tileSize[1]) 43 /// step (%arg4, %arg5) 44 /// %inbound = (%j0 * %arg4 + %i0 < %arg2) && 45 /// (%j1 * %arg5 + %i1 < %arg3) 46 /// scf.if (%inbound) 47 /// .... 48 /// 49 /// where the uses of %i0 and %i1 in the loop body are replaced by 50 /// %i0 + j0 and %i1 + %j1. 51 // 52 /// The old loop is replaced with the new one. 53 std::pair<ParallelOp, ParallelOp> 54 mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, 55 bool noMinMaxBounds) { 56 OpBuilder b(op); 57 auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0); 58 SmallVector<Value, 2> tileSizeConstants; 59 tileSizeConstants.reserve(op.upperBound().size()); 60 for (size_t i = 0, end = op.upperBound().size(); i != end; ++i) { 61 if (i < tileSizes.size()) 62 tileSizeConstants.push_back( 63 b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i])); 64 else 65 // Just pick 1 for the remaining dimensions. 66 tileSizeConstants.push_back( 67 b.create<arith::ConstantIndexOp>(op.getLoc(), 1)); 68 } 69 70 // Create the outer loop with adjusted steps. 71 SmallVector<Value, 2> newSteps; 72 newSteps.reserve(op.step().size()); 73 for (auto step : llvm::zip(op.step(), tileSizeConstants)) { 74 newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step), 75 std::get<1>(step))); 76 } 77 auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.lowerBound(), 78 op.upperBound(), newSteps); 79 b.setInsertionPointToStart(outerLoop.getBody()); 80 81 // Compute min(size, dim - offset) to avoid out-of-bounds accesses. 82 auto minMap = AffineMap::get( 83 /*dimCount=*/3, /*symbolCount=*/0, 84 {getAffineDimExpr(/*position=*/0, b.getContext()), 85 getAffineDimExpr(/*position=*/1, b.getContext()) - 86 getAffineDimExpr(/*position=*/2, b.getContext())}, 87 b.getContext()); 88 89 // Create the inner loop with adjusted bounds. 90 SmallVector<Value, 2> newBounds; 91 newBounds.reserve(op.upperBound().size()); 92 bool needInboundCheck = false; 93 for (auto dim : llvm::zip(outerLoop.lowerBound(), outerLoop.upperBound(), 94 outerLoop.step(), outerLoop.getInductionVars(), 95 op.step(), tileSizeConstants)) { 96 Value lowerBound, upperBound, newStep, iv, step, tileSizeConstant; 97 std::tie(lowerBound, upperBound, newStep, iv, step, tileSizeConstant) = dim; 98 // Collect the statically known loop bounds 99 auto lowerBoundConstant = 100 dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp()); 101 auto upperBoundConstant = 102 dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp()); 103 auto stepConstant = 104 dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp()); 105 auto tileSize = 106 cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value(); 107 // If the loop bounds and the loop step are constant and if the number of 108 // loop iterations is an integer multiple of the tile size, we use a static 109 // bound for the inner loop. 110 if (lowerBoundConstant && upperBoundConstant && stepConstant) { 111 auto numIterations = llvm::divideCeil(upperBoundConstant.value() - 112 lowerBoundConstant.value(), 113 stepConstant.value()); 114 if (numIterations % tileSize == 0) { 115 newBounds.push_back(newStep); 116 continue; 117 } 118 } 119 120 // For InboundCheck mode, just use the variable outer step 121 if (noMinMaxBounds) { 122 newBounds.push_back(newStep); 123 needInboundCheck = true; 124 continue; 125 } 126 127 // Otherwise, we dynamically compute the bound for 128 // each iteration of the outer loop. 129 newBounds.push_back( 130 b.create<AffineMinOp>(op.getLoc(), b.getIndexType(), minMap, 131 ValueRange{newStep, upperBound, iv})); 132 } 133 auto innerLoop = b.create<ParallelOp>( 134 op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds, 135 op.step()); 136 137 if (noMinMaxBounds && needInboundCheck) { 138 b.setInsertionPointToStart(innerLoop.getBody()); 139 // Insert in-bound check 140 Value inbound = 141 b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1)); 142 for (auto dim : 143 llvm::zip(outerLoop.upperBound(), outerLoop.getInductionVars(), 144 innerLoop.getInductionVars(), innerLoop.step())) { 145 Value outerUpperBound, outerIV, innerIV, innerStep; 146 std::tie(outerUpperBound, outerIV, innerIV, innerStep) = dim; 147 // %in_bound = %in_bound && 148 // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound) 149 Value index = b.create<arith::AddIOp>( 150 op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep), 151 outerIV); 152 Value dimInbound = b.create<arith::CmpIOp>( 153 op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound); 154 inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound); 155 } 156 auto ifInbound = b.create<IfOp>(op.getLoc(), 157 /*resultTypes*/ ArrayRef<Type>{}, inbound, 158 /*hasElseRegion*/ false); 159 ifInbound.thenRegion().takeBody(op.region()); 160 Block &thenBlock = ifInbound.thenRegion().front(); 161 b.setInsertionPointToStart(innerLoop.getBody()); 162 for (auto ivs : llvm::enumerate(llvm::zip(innerLoop.getInductionVars(), 163 outerLoop.getInductionVars()))) { 164 auto newIndex = b.create<arith::AddIOp>( 165 op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value())); 166 thenBlock.getArgument(ivs.index()) 167 .replaceAllUsesExcept(newIndex, newIndex); 168 } 169 thenBlock.eraseArguments(llvm::to_vector<4>( 170 llvm::seq((unsigned)0, thenBlock.getNumArguments()))); 171 } else { 172 innerLoop.region().takeBody(op.region()); 173 b.setInsertionPointToStart(innerLoop.getBody()); 174 for (auto ivs : llvm::zip(innerLoop.getInductionVars(), 175 outerLoop.getInductionVars())) { 176 Value innerIndex = std::get<0>(ivs); 177 auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs), 178 std::get<1>(ivs)); 179 innerIndex.replaceAllUsesExcept(newIndex, newIndex); 180 } 181 } 182 183 op.erase(); 184 return std::make_pair(outerLoop, innerLoop); 185 } 186 187 namespace { 188 struct ParallelLoopTiling 189 : public SCFParallelLoopTilingBase<ParallelLoopTiling> { 190 ParallelLoopTiling() = default; 191 explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes, 192 bool noMinMaxBounds = false) { 193 this->tileSizes = tileSizes; 194 this->noMinMaxBounds = noMinMaxBounds; 195 } 196 197 void runOnFunction() override { 198 SmallVector<ParallelOp, 2> innermostPloops; 199 getInnermostParallelLoops(getFunction().getOperation(), innermostPloops); 200 for (ParallelOp ploop : innermostPloops) { 201 // FIXME: Add reduction support. 202 if (ploop.getNumReductions() == 0) 203 tileParallelLoop(ploop, tileSizes, noMinMaxBounds); 204 } 205 } 206 }; 207 } // namespace 208 209 std::unique_ptr<Pass> 210 mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes, 211 bool noMinMaxBounds) { 212 return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds); 213 } 214