//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite an scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   %c0 = constant 0 : index
//   %c1 = constant 1 : index
//
//   // Compute block sizes for each induction variable.
//   %num_blocks_i = ... : index
//   %num_blocks_j = ... : index
//   %block_size_i = ... : index
//   %block_size_j = ... : index
//
//   // Create an async group to track async.execute ops.
//   %group = async.create_group
//
//   scf.for %bi = %c0 to %num_blocks_i step %c1 {
//     %block_start_i = ... : index
//     %block_end_i   = ... : index
//
//     scf.for %bj = %c0 to %num_blocks_j step %c1 {
//       %block_start_j = ... : index
//       %block_end_j   = ... : index
//
//       // Execute the body of the original parallel operation for the
//       // current block.
//       %token = async.execute {
//         scf.for %i = %block_start_i to %block_end_i step %si {
//           scf.for %j = %block_start_j to %block_end_j step %sj {
//             "do_some_compute"(%i, %j): () -> ()
//           }
//         }
//       }
//
//       // Add the produced async token to the group.
//       async.add_to_group %token, %group
//     }
//   }
//
//   // Await completion of all async.execute operations.
//   async.await_all %group
//
// In this example the outer loop nest launches the inner block-level loops as
// separate async.execute operations that will be executed concurrently.
//
// At the end it waits for the completion of all async.execute operations.
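//
// As an illustrative sketch of the block-size heuristic (the numbers are
// assumed for this comment, not produced by the pass for any particular
// input): with unit steps, zero lower bounds, trip counts (128, 16) and 8
// concurrent async.execute ops, the pass would pick %num_blocks_i = 8 and
// %block_size_i = 16; the outer dimension already provides all the requested
// concurrency, so %num_blocks_j = 1 and %block_size_j = 16.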
//
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute)
      : OpRewritePattern(ctx),
        numConcurrentAsyncExecute(numConcurrentAsyncExecute) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  int numConcurrentAsyncExecute;
};

struct AsyncParallelForPass
    : public AsyncParallelForBase<AsyncParallelForPass> {
  AsyncParallelForPass() = default;
  AsyncParallelForPass(int numWorkerThreads) {
    assert(numWorkerThreads >= 1);
    numConcurrentAsyncExecute = numWorkerThreads;
  }
  void runOnOperation() override;
};

} // namespace

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support the rewrite for parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();

  // Index constants used below.
  auto indexTy = IndexType::get(ctx);
  auto zero = IntegerAttr::get(indexTy, 0);
  auto one = IntegerAttr::get(indexTy, 1);
  auto c0 = rewriter.create<ConstantOp>(loc, indexTy, zero);
  auto c1 = rewriter.create<ConstantOp>(loc, indexTy, one);

  // Shorthand for the signed integer ceil division operation.
  auto divup = [&](Value x, Value y) -> Value {
    return rewriter.create<SignedCeilDivIOp>(loc, x, y);
  };

  // Compute the trip count for each loop induction variable:
  //   tripCount = divUp(upperBound - lowerBound, step);
  SmallVector<Value, 4> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.lowerBound()[i];
    auto ub = op.upperBound()[i];
    auto step = op.step()[i];
    auto range = rewriter.create<SubIOp>(loc, ub, lb);
    tripCounts[i] = divup(range, step);
  }

  // The target number of concurrent async.execute ops.
  auto numExecuteOps = rewriter.create<ConstantOp>(
      loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute));

  // Block size configuration for each induction variable.

  // We try to use the maximum available concurrency in the outer dimensions
  // first (assuming that the parallel induction variables correspond to some
  // multidimensional access, e.g. in (%d0, %d1, ..., %dn) = (<from>) to (<to>)
  // we will try to parallelize the iteration along %d0 first). If %d0 is too
  // small, we'll parallelize the iteration over %d1, and so on.
  SmallVector<Value, 4> targetNumBlocks(op.getNumLoops());
  SmallVector<Value, 4> blockSize(op.getNumLoops());
  SmallVector<Value, 4> numBlocks(op.getNumLoops());

  // Compute the block size and the number of blocks along the first induction
  // variable.
  targetNumBlocks[0] = numExecuteOps;
  blockSize[0] = divup(tripCounts[0], targetNumBlocks[0]);
  numBlocks[0] = divup(tripCounts[0], blockSize[0]);

  // Assign the remaining available concurrency to the other induction
  // variables.
  for (size_t i = 1; i < op.getNumLoops(); ++i) {
    targetNumBlocks[i] = divup(targetNumBlocks[i - 1], numBlocks[i - 1]);
    blockSize[i] = divup(tripCounts[i], targetNumBlocks[i]);
    numBlocks[i] = divup(tripCounts[i], blockSize[i]);
  }
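
  // Illustrative trace of the heuristic above (the values are assumed for
  // this comment, not computed in this file): with tripCounts = (2, 100) and
  // numConcurrentAsyncExecute = 4, the first dimension only yields
  // numBlocks[0] = 2 blocks of size 1, so the remaining concurrency
  // targetNumBlocks[1] = divup(4, 2) = 2 spills into the second dimension,
  // giving blockSize[1] = 50 and numBlocks[1] = 2, i.e. 2 x 2 = 4 concurrent
  // blocks in total.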

  // Create an async.group to wait on all async tokens from the async.execute
  // ops.
  auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));

  // Build an scf.for loop nest from the parallel operation.

  // Lower/upper bounds for the nested block-level computations.
  SmallVector<Value, 4> blockLowerBounds(op.getNumLoops());
  SmallVector<Value, 4> blockUpperBounds(op.getNumLoops());
  SmallVector<Value, 4> blockInductionVars(op.getNumLoops());

  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Builds the inner loop nest inside the async.execute operation that does
  // all the work concurrently.
  LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
      blockInductionVars[loopIdx] = iv;

      // Continue building the async loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        b.create<scf::ForOp>(
            loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1],
            op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1));
        b.create<scf::YieldOp>(loc);
        return;
      }

      // Copy the body of the parallel op with the new loop bounds.
      BlockAndValueMapping mapping;
      mapping.map(op.getInductionVars(), blockInductionVars);

      for (auto &bodyOp : op.getLoopBody().getOps())
        b.clone(bodyOp, mapping);
    };
  };

  // Builds a loop nest that dispatches the async.execute ops.
  LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
      auto lb = op.lowerBound()[loopIdx];
      auto ub = op.upperBound()[loopIdx];
      auto step = op.step()[loopIdx];

      // Compute the lower bound for the current block:
      //   blockLowerBound = iv * blockSize * step + lowerBound
      auto s0 = b.create<MulIOp>(loc, iv, blockSize[loopIdx]);
      auto s1 = b.create<MulIOp>(loc, s0, step);
      auto s2 = b.create<AddIOp>(loc, s1, lb);
      blockLowerBounds[loopIdx] = s2;

      // Compute the upper bound for the current block:
      //   blockUpperBound = min(upperBound,
      //                         blockLowerBound + blockSize * step)
      auto e0 = b.create<MulIOp>(loc, blockSize[loopIdx], step);
      auto e1 = b.create<AddIOp>(loc, e0, s2);
      auto e2 = b.create<CmpIOp>(loc, CmpIPredicate::slt, e1, ub);
      auto e3 = b.create<SelectOp>(loc, e2, e1, ub);
      blockUpperBounds[loopIdx] = e3;

      // Continue building the async dispatch loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        b.create<scf::ForOp>(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(),
                             asyncLoopBuilder(loopIdx + 1));
        b.create<scf::YieldOp>(loc);
        return;
      }
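
      // At this point the block lower/upper bounds for every dimension are
      // available. As an illustrative check (the values are assumed for this
      // comment, not computed by the pass): with lowerBound = 0, step = 2,
      // blockSize = 5 and iv = 3 the block covers [30, min(upperBound, 40)),
      // i.e. a trailing partial block is clamped to the original upper bound.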

      // Build the inner loop nest that will do the actual work inside the
      // `async.execute` body region.
      auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                    Location executeLoc,
                                    ValueRange executeArgs) {
        executeBuilder.create<scf::ForOp>(executeLoc, blockLowerBounds[0],
                                          blockUpperBounds[0], op.step()[0],
                                          ValueRange(), workLoopBuilder(0));
        executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
      };

      auto execute = b.create<ExecuteOp>(
          loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(),
          /*operands=*/ValueRange(), executeBodyBuilder);
      auto rankType = IndexType::get(ctx);
      b.create<AddToGroupOp>(loc, rankType, execute.token(), group.result());
      b.create<scf::YieldOp>(loc);
    };
  };

  // Start building the loop nest from the first induction variable.
  rewriter.create<scf::ForOp>(loc, c0, numBlocks[0], c1, ValueRange(),
                              asyncLoopBuilder(0));

  // Wait for the completion of all subtasks.
  rewriter.create<AwaitAllOp>(loc, group.result());

  // Erase the original parallel operation.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);

  if (failed(
          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
  return std::make_unique<AsyncParallelForPass>();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass(int numWorkerThreads) {
  return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
}
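
// Example of adding this pass to a pipeline (an illustrative sketch, not part
// of this pass; it assumes the usual mlir::PassManager APIs and that the pass
// is nested on the enclosing function operation):
//
//   mlir::PassManager pm(ctx);
//   pm.addNestedPass<mlir::FuncOp>(
//       mlir::createAsyncParallelForPass(/*numWorkerThreads=*/8));
//   if (failed(pm.run(module)))
//     /* handle the failure */;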