1 //===- ParallelLoopTiling.cpp - Tiles scf.parallel ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements loop tiling on parallel loops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/SCF/Passes.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/SCF/Transforms.h" 19 #include "mlir/Dialect/SCF/Utils.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 22 using namespace mlir; 23 using namespace mlir::scf; 24 25 /// Tile a parallel loop of the form 26 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 27 /// step (%arg4, %arg5) 28 /// 29 /// into 30 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 31 /// step (%arg4*tileSize[0], 32 /// %arg5*tileSize[1]) 33 /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0) 34 /// min(%arg5*tileSize[1], %arg3-%i1)) 35 /// step (%arg4, %arg5) 36 /// 37 /// or, when no-min-max-bounds is true, into 38 /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) 39 /// step (%arg4*tileSize[0], 40 /// %arg5*tileSize[1]) 41 /// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0], 42 /// %arg5*tileSize[1]) 43 /// step (%arg4, %arg5) 44 /// %inbound = (%j0 * %arg4 + %i0 < %arg2) && 45 /// (%j1 * %arg5 + %i1 < %arg3) 46 /// scf.if (%inbound) 47 /// .... 48 /// 49 /// where the uses of %i0 and %i1 in the loop body are replaced by 50 /// %i0 + j0 and %i1 + %j1. 51 // 52 /// The old loop is replaced with the new one. 53 std::pair<ParallelOp, ParallelOp> 54 mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, 55 bool noMinMaxBounds) { 56 OpBuilder b(op); 57 auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0); 58 SmallVector<Value, 2> tileSizeConstants; 59 tileSizeConstants.reserve(op.getUpperBound().size()); 60 for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) { 61 if (i < tileSizes.size()) 62 tileSizeConstants.push_back( 63 b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i])); 64 else 65 // Just pick 1 for the remaining dimensions. 66 tileSizeConstants.push_back( 67 b.create<arith::ConstantIndexOp>(op.getLoc(), 1)); 68 } 69 70 // Create the outer loop with adjusted steps. 71 SmallVector<Value, 2> newSteps; 72 newSteps.reserve(op.getStep().size()); 73 for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) { 74 newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step), 75 std::get<1>(step))); 76 } 77 auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(), 78 op.getUpperBound(), newSteps); 79 b.setInsertionPointToStart(outerLoop.getBody()); 80 81 // Compute min(size, dim - offset) to avoid out-of-bounds accesses. 82 auto minMap = AffineMap::get( 83 /*dimCount=*/3, /*symbolCount=*/0, 84 {getAffineDimExpr(/*position=*/0, b.getContext()), 85 getAffineDimExpr(/*position=*/1, b.getContext()) - 86 getAffineDimExpr(/*position=*/2, b.getContext())}, 87 b.getContext()); 88 89 // Create the inner loop with adjusted bounds. 90 SmallVector<Value, 2> newBounds; 91 newBounds.reserve(op.getUpperBound().size()); 92 bool needInboundCheck = false; 93 for (auto dim : 94 llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(), 95 outerLoop.getStep(), outerLoop.getInductionVars(), 96 op.getStep(), tileSizeConstants)) { 97 Value lowerBound, upperBound, newStep, iv, step, tileSizeConstant; 98 std::tie(lowerBound, upperBound, newStep, iv, step, tileSizeConstant) = dim; 99 // Collect the statically known loop bounds 100 auto lowerBoundConstant = 101 dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp()); 102 auto upperBoundConstant = 103 dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp()); 104 auto stepConstant = 105 dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp()); 106 auto tileSize = 107 cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value(); 108 // If the loop bounds and the loop step are constant and if the number of 109 // loop iterations is an integer multiple of the tile size, we use a static 110 // bound for the inner loop. 111 if (lowerBoundConstant && upperBoundConstant && stepConstant) { 112 auto numIterations = llvm::divideCeil(upperBoundConstant.value() - 113 lowerBoundConstant.value(), 114 stepConstant.value()); 115 if (numIterations % tileSize == 0) { 116 newBounds.push_back(newStep); 117 continue; 118 } 119 } 120 121 // For InboundCheck mode, just use the variable outer step 122 if (noMinMaxBounds) { 123 newBounds.push_back(newStep); 124 needInboundCheck = true; 125 continue; 126 } 127 128 // Otherwise, we dynamically compute the bound for 129 // each iteration of the outer loop. 130 newBounds.push_back( 131 b.create<AffineMinOp>(op.getLoc(), b.getIndexType(), minMap, 132 ValueRange{newStep, upperBound, iv})); 133 } 134 auto innerLoop = b.create<ParallelOp>( 135 op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds, 136 op.getStep()); 137 138 if (noMinMaxBounds && needInboundCheck) { 139 b.setInsertionPointToStart(innerLoop.getBody()); 140 // Insert in-bound check 141 Value inbound = 142 b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1)); 143 for (auto dim : 144 llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(), 145 innerLoop.getInductionVars(), innerLoop.getStep())) { 146 Value outerUpperBound, outerIV, innerIV, innerStep; 147 std::tie(outerUpperBound, outerIV, innerIV, innerStep) = dim; 148 // %in_bound = %in_bound && 149 // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound) 150 Value index = b.create<arith::AddIOp>( 151 op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep), 152 outerIV); 153 Value dimInbound = b.create<arith::CmpIOp>( 154 op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound); 155 inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound); 156 } 157 auto ifInbound = b.create<IfOp>(op.getLoc(), 158 /*resultTypes*/ ArrayRef<Type>{}, inbound, 159 /*hasElseRegion*/ false); 160 ifInbound.getThenRegion().takeBody(op.getRegion()); 161 Block &thenBlock = ifInbound.getThenRegion().front(); 162 b.setInsertionPointToStart(innerLoop.getBody()); 163 for (const auto &ivs : llvm::enumerate(llvm::zip( 164 innerLoop.getInductionVars(), outerLoop.getInductionVars()))) { 165 auto newIndex = b.create<arith::AddIOp>( 166 op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value())); 167 thenBlock.getArgument(ivs.index()) 168 .replaceAllUsesExcept(newIndex, newIndex); 169 } 170 thenBlock.eraseArguments(llvm::to_vector<4>( 171 llvm::seq((unsigned)0, thenBlock.getNumArguments()))); 172 } else { 173 innerLoop.getRegion().takeBody(op.getRegion()); 174 b.setInsertionPointToStart(innerLoop.getBody()); 175 for (auto ivs : llvm::zip(innerLoop.getInductionVars(), 176 outerLoop.getInductionVars())) { 177 Value innerIndex = std::get<0>(ivs); 178 auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs), 179 std::get<1>(ivs)); 180 innerIndex.replaceAllUsesExcept(newIndex, newIndex); 181 } 182 } 183 184 op.erase(); 185 return std::make_pair(outerLoop, innerLoop); 186 } 187 188 namespace { 189 struct ParallelLoopTiling 190 : public SCFParallelLoopTilingBase<ParallelLoopTiling> { 191 ParallelLoopTiling() = default; 192 explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes, 193 bool noMinMaxBounds = false) { 194 this->tileSizes = tileSizes; 195 this->noMinMaxBounds = noMinMaxBounds; 196 } 197 198 void runOnOperation() override { 199 SmallVector<ParallelOp, 2> innermostPloops; 200 getInnermostParallelLoops(getOperation().getOperation(), innermostPloops); 201 for (ParallelOp ploop : innermostPloops) { 202 // FIXME: Add reduction support. 203 if (ploop.getNumReductions() == 0) 204 tileParallelLoop(ploop, tileSizes, noMinMaxBounds); 205 } 206 } 207 }; 208 } // namespace 209 210 std::unique_ptr<Pass> 211 mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes, 212 bool noMinMaxBounds) { 213 return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds); 214 } 215