//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite an scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   %c0 = constant 0 : index
//   %c1 = constant 1 : index
//
//   // Compute block sizes for each induction variable.
//   %num_blocks_i = ... : index
//   %num_blocks_j = ... : index
//   %block_size_i = ... : index
//   %block_size_j = ... : index
//
//   // Create an async group to track async execute ops.
//   %group = async.create_group
//
//   scf.for %bi = %c0 to %num_blocks_i step %c1 {
//     %block_start_i = ... : index
//     %block_end_i   = ... : index
//
//     scf.for %bj = %c0 to %num_blocks_j step %c1 {
//       %block_start_j = ... : index
//       %block_end_j   = ... : index
//
//       // Execute the body of the original parallel operation for the
//       // current block.
//       %token = async.execute {
//         scf.for %i = %block_start_i to %block_end_i step %si {
//           scf.for %j = %block_start_j to %block_end_j step %sj {
//             "do_some_compute"(%i, %j): () -> ()
//           }
//         }
//       }
//
//       // Add the produced async token to the group.
//       async.add_to_group %token, %group
//     }
//   }
//
//   // Await completion of all async.execute operations.
//   async.await_all %group
//
// In this example the outer loop launches the inner block-level loops as
// separate async.execute operations that will be executed concurrently. At
// the end it waits for the completion of all async.execute operations.
//
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute)
      : OpRewritePattern(ctx),
        numConcurrentAsyncExecute(numConcurrentAsyncExecute) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  int numConcurrentAsyncExecute;
};

struct AsyncParallelForPass
    : public AsyncParallelForBase<AsyncParallelForPass> {
  AsyncParallelForPass() = default;
  AsyncParallelForPass(int numWorkerThreads) {
    assert(numWorkerThreads >= 1);
    numConcurrentAsyncExecute = numWorkerThreads;
  }
  void runOnFunction() override;
};

} // namespace

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support the rewrite for parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();

  // Index constants used below.
  auto indexTy = IndexType::get(ctx);
  auto zero = IntegerAttr::get(indexTy, 0);
  auto one = IntegerAttr::get(indexTy, 1);
  auto c0 = rewriter.create<ConstantOp>(loc, indexTy, zero);
  auto c1 = rewriter.create<ConstantOp>(loc, indexTy, one);

  // Shorthand for the signed integer ceil division operation.
  auto divup = [&](Value x, Value y) -> Value {
    return rewriter.create<SignedCeilDivIOp>(loc, x, y);
  };

  // Compute the trip count for each loop induction variable:
  //   tripCount = divUp(upperBound - lowerBound, step);
  SmallVector<Value, 4> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.lowerBound()[i];
    auto ub = op.upperBound()[i];
    auto step = op.step()[i];
    auto range = rewriter.create<SubIOp>(loc, ub, lb);
    tripCounts[i] = divup(range, step);
  }

  // The target number of concurrent async.execute ops.
  auto numExecuteOps = rewriter.create<ConstantOp>(
      loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute));

  // Block size configuration for each induction variable.

  // We try to use the maximum available concurrency in outer dimensions first
  // (assuming that the parallel induction variables correspond to some
  // multidimensional access, e.g. (%d0, %d1, ..., %dn) = (<from>) to (<to>)).
  // We first try to parallelize the iteration along %d0; if %d0 is too small,
  // we'll parallelize the iteration over %d1, and so on.
  SmallVector<Value, 4> targetNumBlocks(op.getNumLoops());
  SmallVector<Value, 4> blockSize(op.getNumLoops());
  SmallVector<Value, 4> numBlocks(op.getNumLoops());

  // Compute the block size and the number of blocks along the first induction
  // variable.
  targetNumBlocks[0] = numExecuteOps;
  blockSize[0] = divup(tripCounts[0], targetNumBlocks[0]);
  numBlocks[0] = divup(tripCounts[0], blockSize[0]);

  // Assign the remaining available concurrency to other induction variables.
  for (size_t i = 1; i < op.getNumLoops(); ++i) {
    targetNumBlocks[i] = divup(targetNumBlocks[i - 1], numBlocks[i - 1]);
    blockSize[i] = divup(tripCounts[i], targetNumBlocks[i]);
    numBlocks[i] = divup(tripCounts[i], blockSize[i]);
  }
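
  // Illustrative example (hypothetical numbers; in the emitted IR these are
  // runtime index values): with numConcurrentAsyncExecute = 8 and
  // tripCounts = (2, 100), the first loop yields only numBlocks[0] = 2 blocks
  // (blockSize[0] = divup(2, 8) = 1), so the leftover concurrency
  // divup(8, 2) = 4 goes to the second loop: blockSize[1] = divup(100, 4) = 25
  // and numBlocks[1] = 4, giving 2 * 4 = 8 concurrent blocks in total.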

  // Create an async.group to wait on all async tokens from async.execute ops.
  auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));

  // Build an scf.for loop nest from the parallel operation.

  // Lower/upper bounds for the block-level computations in the loop nest.
  SmallVector<Value, 4> blockLowerBounds(op.getNumLoops());
  SmallVector<Value, 4> blockUpperBounds(op.getNumLoops());
  SmallVector<Value, 4> blockInductionVars(op.getNumLoops());

  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Builds the inner loop nest inside the async.execute operation that does
  // all the work concurrently.
  LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &b, Location loc, Value iv,
                        ValueRange args) {
      blockInductionVars[loopIdx] = iv;

      // Continue building the async loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        b.create<scf::ForOp>(
            loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1],
            op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1));
        b.create<scf::YieldOp>(loc);
        return;
      }

      // Copy the body of the parallel op with the new loop bounds.
      BlockAndValueMapping mapping;
      mapping.map(op.getInductionVars(), blockInductionVars);

      for (auto &bodyOp : op.getLoopBody().getOps())
        b.clone(bodyOp, mapping);
    };
  };
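
  // For example, for a two-dimensional scf.parallel, workLoopBuilder(0) builds
  // the body of the outer block-level scf.for: it records the induction
  // variable and creates the inner scf.for with workLoopBuilder(1) as its body
  // builder; workLoopBuilder(1) hits the base case and clones the original
  // parallel body with all induction variables remapped to the block-level
  // ones.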

  // Builds the loop nest that does the async.execute op dispatching.
  LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &b, Location loc, Value iv,
                        ValueRange args) {
      auto lb = op.lowerBound()[loopIdx];
      auto ub = op.upperBound()[loopIdx];
      auto step = op.step()[loopIdx];

      // Compute the lower bound for the current block:
      //   blockLowerBound = iv * blockSize * step + lowerBound
      auto s0 = b.create<MulIOp>(loc, iv, blockSize[loopIdx]);
      auto s1 = b.create<MulIOp>(loc, s0, step);
      auto s2 = b.create<AddIOp>(loc, s1, lb);
      blockLowerBounds[loopIdx] = s2;

      // Compute the upper bound for the current block:
      //   blockUpperBound = min(upperBound,
      //                         blockLowerBound + blockSize * step)
      auto e0 = b.create<MulIOp>(loc, blockSize[loopIdx], step);
      auto e1 = b.create<AddIOp>(loc, e0, s2);
      auto e2 = b.create<CmpIOp>(loc, CmpIPredicate::slt, e1, ub);
      auto e3 = b.create<SelectOp>(loc, e2, e1, ub);
      blockUpperBounds[loopIdx] = e3;

      // Continue building the async dispatch loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        b.create<scf::ForOp>(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(),
                             asyncLoopBuilder(loopIdx + 1));
        b.create<scf::YieldOp>(loc);
        return;
      }

      // Build the inner loop nest that will do the actual work inside the
      // `async.execute` body region.
      auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                    Location executeLoc,
                                    ValueRange executeArgs) {
        executeBuilder.create<scf::ForOp>(executeLoc, blockLowerBounds[0],
                                          blockUpperBounds[0], op.step()[0],
                                          ValueRange(), workLoopBuilder(0));
        executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
      };

      auto execute = b.create<ExecuteOp>(
          loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(),
          /*operands=*/ValueRange(), executeBodyBuilder);
      auto rankType = IndexType::get(ctx);
      b.create<AddToGroupOp>(loc, rankType, execute.token(), group.result());
      b.create<scf::YieldOp>(loc);
    };
  };
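
  // Illustrative example (hypothetical values; in the emitted IR these are
  // runtime index values): with lowerBound = 0, step = 2, blockSize = 25 and
  // block index iv = 1, the dispatched block covers the iteration range
  // [1 * 25 * 2 + 0, min(ub, 50 + 25 * 2)) = [50, min(ub, 100)).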

  // Start building the loop nest from the first induction variable.
  rewriter.create<scf::ForOp>(loc, c0, numBlocks[0], c1, ValueRange(),
                              asyncLoopBuilder(0));

  // Wait for the completion of all subtasks.
  rewriter.create<AwaitAllOp>(loc, group.result());

  // Erase the original parallel operation.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnFunction() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);

  if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
  return std::make_unique<AsyncParallelForPass>();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncParallelForPass(int numWorkerThreads) {
  return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
}
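
// A minimal usage sketch (not part of this file; the PassManager setup below
// is illustrative boilerplate assuming a standard MLIR pass pipeline):
//
//   mlir::PassManager pm(ctx);
//   pm.addNestedPass<mlir::FuncOp>(
//       mlir::createAsyncParallelForPass(/*numWorkerThreads=*/8));
//   if (mlir::failed(pm.run(module)))
//     /* handle the error */;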