1 //===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop tiling on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/SCF/Transforms/Passes.h" 18 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 21 using namespace mlir; 22 using namespace mlir::scf; 23 24 /// Tile a parallel loop of the form 25 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 26 /// step (%arg4, %arg5) 27 /// 28 /// into 29 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 30 /// step (%arg4*tileSize[0], 31 /// %arg5*tileSize[1]) 32 /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0) 33 /// min(%arg5*tileSize[1], %arg3-%i1)) 34 /// step (%arg4, %arg5) 35 /// 36 /// or, when no-min-max-bounds is true, into 37 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 38 /// step (%arg4*tileSize[0], 39 /// %arg5*tileSize[1]) 40 /// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0], 41 /// %arg5*tileSize[1]) 42 /// step (%arg4, %arg5) 43 /// %inbound = (%j0 * %arg4 + %i0 < %arg2) && 44 /// (%j1 * %arg5 + %i1 < %arg3) 45 /// scf.if (%inbound) 46 /// .... 47 /// 48 /// where the uses of %i0 and %i1 in the loop body are replaced by 49 /// %i0 + j0 and %i1 + %j1. 50 /// 51 /// The old loop is replaced with the new one. 52 std::pair<ParallelOp, ParallelOp> 53 mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, 54 bool noMinMaxBounds) { 55 OpBuilder b(op); 56 auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0); 57 SmallVector<Value, 2> tileSizeConstants; 58 tileSizeConstants.reserve(op.getUpperBound().size()); 59 for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) { 60 if (i < tileSizes.size()) 61 tileSizeConstants.push_back( 62 b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i])); 63 else 64 // Just pick 1 for the remaining dimensions. 65 tileSizeConstants.push_back( 66 b.create<arith::ConstantIndexOp>(op.getLoc(), 1)); 67 } 68 69 // Create the outer loop with adjusted steps. 70 SmallVector<Value, 2> newSteps; 71 newSteps.reserve(op.getStep().size()); 72 for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) { 73 newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step), 74 std::get<1>(step))); 75 } 76 auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(), 77 op.getUpperBound(), newSteps); 78 b.setInsertionPointToStart(outerLoop.getBody()); 79 80 // Compute min(size, dim - offset) to avoid out-of-bounds accesses. 81 auto minMap = AffineMap::get( 82 /*dimCount=*/3, /*symbolCount=*/0, 83 {getAffineDimExpr(/*position=*/0, b.getContext()), 84 getAffineDimExpr(/*position=*/1, b.getContext()) - 85 getAffineDimExpr(/*position=*/2, b.getContext())}, 86 b.getContext()); 87 88 // Create the inner loop with adjusted bounds. 89 SmallVector<Value, 2> newBounds; 90 newBounds.reserve(op.getUpperBound().size()); 91 bool needInboundCheck = false; 92 for (auto dim : 93 llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(), 94 outerLoop.getStep(), outerLoop.getInductionVars(), 95 op.getStep(), tileSizeConstants)) { 96 Value lowerBound, upperBound, newStep, iv, step, tileSizeConstant; 97 std::tie(lowerBound, upperBound, newStep, iv, step, tileSizeConstant) = dim; 98 // Collect the statically known loop bounds 99 auto lowerBoundConstant = 100 dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp()); 101 auto upperBoundConstant = 102 dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp()); 103 auto stepConstant = 104 dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp()); 105 auto tileSize = 106 cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value(); 107 // If the loop bounds and the loop step are constant and if the number of 108 // loop iterations is an integer multiple of the tile size, we use a static 109 // bound for the inner loop. 110 if (lowerBoundConstant && upperBoundConstant && stepConstant) { 111 auto numIterations = llvm::divideCeil(upperBoundConstant.value() - 112 lowerBoundConstant.value(), 113 stepConstant.value()); 114 if (numIterations % tileSize == 0) { 115 newBounds.push_back(newStep); 116 continue; 117 } 118 } 119 120 // For InboundCheck mode, just use the variable outer step 121 if (noMinMaxBounds) { 122 newBounds.push_back(newStep); 123 needInboundCheck = true; 124 continue; 125 } 126 127 // Otherwise, we dynamically compute the bound for 128 // each iteration of the outer loop. 129 newBounds.push_back( 130 b.create<AffineMinOp>(op.getLoc(), b.getIndexType(), minMap, 131 ValueRange{newStep, upperBound, iv})); 132 } 133 auto innerLoop = b.create<ParallelOp>( 134 op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds, 135 op.getStep()); 136 137 if (noMinMaxBounds && needInboundCheck) { 138 b.setInsertionPointToStart(innerLoop.getBody()); 139 // Insert in-bound check 140 Value inbound = 141 b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1)); 142 for (auto dim : 143 llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(), 144 innerLoop.getInductionVars(), innerLoop.getStep())) { 145 Value outerUpperBound, outerIV, innerIV, innerStep; 146 std::tie(outerUpperBound, outerIV, innerIV, innerStep) = dim; 147 // %in_bound = %in_bound && 148 // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound) 149 Value index = b.create<arith::AddIOp>( 150 op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep), 151 outerIV); 152 Value dimInbound = b.create<arith::CmpIOp>( 153 op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound); 154 inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound); 155 } 156 auto ifInbound = b.create<IfOp>(op.getLoc(), 157 /*resultTypes*/ ArrayRef<Type>{}, inbound, 158 /*hasElseRegion*/ false); 159 ifInbound.getThenRegion().takeBody(op.getRegion()); 160 Block &thenBlock = ifInbound.getThenRegion().front(); 161 b.setInsertionPointToStart(innerLoop.getBody()); 162 for (const auto &ivs : llvm::enumerate(llvm::zip( 163 innerLoop.getInductionVars(), outerLoop.getInductionVars()))) { 164 auto newIndex = b.create<arith::AddIOp>( 165 op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value())); 166 thenBlock.getArgument(ivs.index()) 167 .replaceAllUsesExcept(newIndex, newIndex); 168 } 169 thenBlock.eraseArguments(llvm::to_vector<4>( 170 llvm::seq((unsigned)0, thenBlock.getNumArguments()))); 171 } else { 172 innerLoop.getRegion().takeBody(op.getRegion()); 173 b.setInsertionPointToStart(innerLoop.getBody()); 174 for (auto ivs : llvm::zip(innerLoop.getInductionVars(), 175 outerLoop.getInductionVars())) { 176 Value innerIndex = std::get<0>(ivs); 177 auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs), 178 std::get<1>(ivs)); 179 innerIndex.replaceAllUsesExcept(newIndex, newIndex); 180 } 181 } 182 183 op.erase(); 184 return std::make_pair(outerLoop, innerLoop); 185 } 186 187 namespace { 188 struct ParallelLoopTiling 189 : public SCFParallelLoopTilingBase<ParallelLoopTiling> { 190 ParallelLoopTiling() = default; 191 explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes, 192 bool noMinMaxBounds = false) { 193 this->tileSizes = tileSizes; 194 this->noMinMaxBounds = noMinMaxBounds; 195 } 196 197 void runOnOperation() override { 198 SmallVector<ParallelOp, 2> innermostPloops; 199 getInnermostParallelLoops(getOperation().getOperation(), innermostPloops); 200 for (ParallelOp ploop : innermostPloops) { 201 // FIXME: Add reduction support. 202 if (ploop.getNumReductions() == 0) 203 tileParallelLoop(ploop, tileSizes, noMinMaxBounds); 204 } 205 } 206 }; 207 } // namespace 208 209 std::unique_ptr<Pass> 210 mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes, 211 bool noMinMaxBounds) { 212 return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds); 213 } 214