1*c30ab6c2SEugene Zhulenev //===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===// 2*c30ab6c2SEugene Zhulenev // 3*c30ab6c2SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*c30ab6c2SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5*c30ab6c2SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*c30ab6c2SEugene Zhulenev // 7*c30ab6c2SEugene Zhulenev //===----------------------------------------------------------------------===// 8*c30ab6c2SEugene Zhulenev // 9*c30ab6c2SEugene Zhulenev // This file implements scf.parallel to scf.for + async.execute conversion pass. 10*c30ab6c2SEugene Zhulenev // 11*c30ab6c2SEugene Zhulenev //===----------------------------------------------------------------------===// 12*c30ab6c2SEugene Zhulenev 13*c30ab6c2SEugene Zhulenev #include "PassDetail.h" 14*c30ab6c2SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 15*c30ab6c2SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h" 16*c30ab6c2SEugene Zhulenev #include "mlir/Dialect/SCF/SCF.h" 17*c30ab6c2SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h" 18*c30ab6c2SEugene Zhulenev #include "mlir/IR/BlockAndValueMapping.h" 19*c30ab6c2SEugene Zhulenev #include "mlir/IR/PatternMatch.h" 20*c30ab6c2SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21*c30ab6c2SEugene Zhulenev 22*c30ab6c2SEugene Zhulenev using namespace mlir; 23*c30ab6c2SEugene Zhulenev using namespace mlir::async; 24*c30ab6c2SEugene Zhulenev 25*c30ab6c2SEugene Zhulenev #define DEBUG_TYPE "async-parallel-for" 26*c30ab6c2SEugene Zhulenev 27*c30ab6c2SEugene Zhulenev namespace { 28*c30ab6c2SEugene Zhulenev 29*c30ab6c2SEugene Zhulenev // Rewrite scf.parallel operation into multiple concurrent async.execute 30*c30ab6c2SEugene Zhulenev // operations over non overlapping subranges of the original loop. 
31*c30ab6c2SEugene Zhulenev // 32*c30ab6c2SEugene Zhulenev // Example: 33*c30ab6c2SEugene Zhulenev // 34*c30ab6c2SEugene Zhulenev // scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) { 35*c30ab6c2SEugene Zhulenev // "do_some_compute"(%i, %j): () -> () 36*c30ab6c2SEugene Zhulenev // } 37*c30ab6c2SEugene Zhulenev // 38*c30ab6c2SEugene Zhulenev // Converted to: 39*c30ab6c2SEugene Zhulenev // 40*c30ab6c2SEugene Zhulenev // %c0 = constant 0 : index 41*c30ab6c2SEugene Zhulenev // %c1 = constant 1 : index 42*c30ab6c2SEugene Zhulenev // 43*c30ab6c2SEugene Zhulenev // // Compute block sizes for each induction variable. 44*c30ab6c2SEugene Zhulenev // %num_blocks_i = ... : index 45*c30ab6c2SEugene Zhulenev // %num_blocks_j = ... : index 46*c30ab6c2SEugene Zhulenev // %block_size_i = ... : index 47*c30ab6c2SEugene Zhulenev // %block_size_j = ... : index 48*c30ab6c2SEugene Zhulenev // 49*c30ab6c2SEugene Zhulenev // // Create an async group to track async execute ops. 50*c30ab6c2SEugene Zhulenev // %group = async.create_group 51*c30ab6c2SEugene Zhulenev // 52*c30ab6c2SEugene Zhulenev // scf.for %bi = %c0 to %num_blocks_i step %c1 { 53*c30ab6c2SEugene Zhulenev // %block_start_i = ... : index 54*c30ab6c2SEugene Zhulenev // %block_end_i = ... : index 55*c30ab6c2SEugene Zhulenev // 56*c30ab6c2SEugene Zhulenev // scf.for %bj = %c0 to %num_blocks_j step %c1 { 57*c30ab6c2SEugene Zhulenev // %block_start_j = ... : index 58*c30ab6c2SEugene Zhulenev // %block_end_j = ... : index 59*c30ab6c2SEugene Zhulenev // 60*c30ab6c2SEugene Zhulenev // // Execute the body of original parallel operation for the current 61*c30ab6c2SEugene Zhulenev // // block. 
62*c30ab6c2SEugene Zhulenev // %token = async.execute { 63*c30ab6c2SEugene Zhulenev // scf.for %i = %block_start_i to %block_end_i step %si { 64*c30ab6c2SEugene Zhulenev // scf.for %j = %block_start_j to %block_end_j step %sj { 65*c30ab6c2SEugene Zhulenev // "do_some_compute"(%i, %j): () -> () 66*c30ab6c2SEugene Zhulenev // } 67*c30ab6c2SEugene Zhulenev // } 68*c30ab6c2SEugene Zhulenev // } 69*c30ab6c2SEugene Zhulenev // 70*c30ab6c2SEugene Zhulenev // // Add produced async token to the group. 71*c30ab6c2SEugene Zhulenev // async.add_to_group %token, %group 72*c30ab6c2SEugene Zhulenev // } 73*c30ab6c2SEugene Zhulenev // } 74*c30ab6c2SEugene Zhulenev // 75*c30ab6c2SEugene Zhulenev // // Await completion of all async.execute operations. 76*c30ab6c2SEugene Zhulenev // async.await_all %group 77*c30ab6c2SEugene Zhulenev // 78*c30ab6c2SEugene Zhulenev // In this example outer loop launches inner block level loops as separate async 79*c30ab6c2SEugene Zhulenev // execute operations which will be executed concurrently. 80*c30ab6c2SEugene Zhulenev // 81*c30ab6c2SEugene Zhulenev // At the end it waits for the completion of all async execute operations. 
82*c30ab6c2SEugene Zhulenev // 83*c30ab6c2SEugene Zhulenev struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> { 84*c30ab6c2SEugene Zhulenev public: 85*c30ab6c2SEugene Zhulenev AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute) 86*c30ab6c2SEugene Zhulenev : OpRewritePattern(ctx), 87*c30ab6c2SEugene Zhulenev numConcurrentAsyncExecute(numConcurrentAsyncExecute) {} 88*c30ab6c2SEugene Zhulenev 89*c30ab6c2SEugene Zhulenev LogicalResult matchAndRewrite(scf::ParallelOp op, 90*c30ab6c2SEugene Zhulenev PatternRewriter &rewriter) const override; 91*c30ab6c2SEugene Zhulenev 92*c30ab6c2SEugene Zhulenev private: 93*c30ab6c2SEugene Zhulenev int numConcurrentAsyncExecute; 94*c30ab6c2SEugene Zhulenev }; 95*c30ab6c2SEugene Zhulenev 96*c30ab6c2SEugene Zhulenev struct AsyncParallelForPass 97*c30ab6c2SEugene Zhulenev : public AsyncParallelForBase<AsyncParallelForPass> { 98*c30ab6c2SEugene Zhulenev AsyncParallelForPass() = default; 99*c30ab6c2SEugene Zhulenev void runOnFunction() override; 100*c30ab6c2SEugene Zhulenev }; 101*c30ab6c2SEugene Zhulenev 102*c30ab6c2SEugene Zhulenev } // namespace 103*c30ab6c2SEugene Zhulenev 104*c30ab6c2SEugene Zhulenev LogicalResult 105*c30ab6c2SEugene Zhulenev AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, 106*c30ab6c2SEugene Zhulenev PatternRewriter &rewriter) const { 107*c30ab6c2SEugene Zhulenev // We do not currently support rewrite for parallel op with reductions. 108*c30ab6c2SEugene Zhulenev if (op.getNumReductions() != 0) 109*c30ab6c2SEugene Zhulenev return failure(); 110*c30ab6c2SEugene Zhulenev 111*c30ab6c2SEugene Zhulenev MLIRContext *ctx = op.getContext(); 112*c30ab6c2SEugene Zhulenev Location loc = op.getLoc(); 113*c30ab6c2SEugene Zhulenev 114*c30ab6c2SEugene Zhulenev // Index constants used below. 
115*c30ab6c2SEugene Zhulenev auto indexTy = IndexType::get(ctx); 116*c30ab6c2SEugene Zhulenev auto zero = IntegerAttr::get(indexTy, 0); 117*c30ab6c2SEugene Zhulenev auto one = IntegerAttr::get(indexTy, 1); 118*c30ab6c2SEugene Zhulenev auto c0 = rewriter.create<ConstantOp>(loc, indexTy, zero); 119*c30ab6c2SEugene Zhulenev auto c1 = rewriter.create<ConstantOp>(loc, indexTy, one); 120*c30ab6c2SEugene Zhulenev 121*c30ab6c2SEugene Zhulenev // Shorthand for signed integer ceil division operation. 122*c30ab6c2SEugene Zhulenev auto divup = [&](Value x, Value y) -> Value { 123*c30ab6c2SEugene Zhulenev return rewriter.create<SignedCeilDivIOp>(loc, x, y); 124*c30ab6c2SEugene Zhulenev }; 125*c30ab6c2SEugene Zhulenev 126*c30ab6c2SEugene Zhulenev // Compute trip count for each loop induction variable: 127*c30ab6c2SEugene Zhulenev // tripCount = divUp(upperBound - lowerBound, step); 128*c30ab6c2SEugene Zhulenev SmallVector<Value, 4> tripCounts(op.getNumLoops()); 129*c30ab6c2SEugene Zhulenev for (size_t i = 0; i < op.getNumLoops(); ++i) { 130*c30ab6c2SEugene Zhulenev auto lb = op.lowerBound()[i]; 131*c30ab6c2SEugene Zhulenev auto ub = op.upperBound()[i]; 132*c30ab6c2SEugene Zhulenev auto step = op.step()[i]; 133*c30ab6c2SEugene Zhulenev auto range = rewriter.create<SubIOp>(loc, ub, lb); 134*c30ab6c2SEugene Zhulenev tripCounts[i] = divup(range, step); 135*c30ab6c2SEugene Zhulenev } 136*c30ab6c2SEugene Zhulenev 137*c30ab6c2SEugene Zhulenev // The target number of concurrent async.execute ops. 138*c30ab6c2SEugene Zhulenev auto numExecuteOps = rewriter.create<ConstantOp>( 139*c30ab6c2SEugene Zhulenev loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute)); 140*c30ab6c2SEugene Zhulenev 141*c30ab6c2SEugene Zhulenev // Blocks sizes configuration for each induction variable. 
142*c30ab6c2SEugene Zhulenev 143*c30ab6c2SEugene Zhulenev // We try to use maximum available concurrency in outer dimensions first 144*c30ab6c2SEugene Zhulenev // (assuming that parallel induction variables are corresponding to some 145*c30ab6c2SEugene Zhulenev // multidimensional access, e.g. in (%d0, %d1, ..., %dn) = (<from>) to (<to>) 146*c30ab6c2SEugene Zhulenev // we will try to parallelize iteration along the %d0. If %d0 is too small, 147*c30ab6c2SEugene Zhulenev // we'll parallelize iteration over %d1, and so on. 148*c30ab6c2SEugene Zhulenev SmallVector<Value, 4> targetNumBlocks(op.getNumLoops()); 149*c30ab6c2SEugene Zhulenev SmallVector<Value, 4> blockSize(op.getNumLoops()); 150*c30ab6c2SEugene Zhulenev SmallVector<Value, 4> numBlocks(op.getNumLoops()); 151*c30ab6c2SEugene Zhulenev 152*c30ab6c2SEugene Zhulenev // Compute block size and number of blocks along the first induction variable. 153*c30ab6c2SEugene Zhulenev targetNumBlocks[0] = numExecuteOps; 154*c30ab6c2SEugene Zhulenev blockSize[0] = divup(tripCounts[0], targetNumBlocks[0]); 155*c30ab6c2SEugene Zhulenev numBlocks[0] = divup(tripCounts[0], blockSize[0]); 156*c30ab6c2SEugene Zhulenev 157*c30ab6c2SEugene Zhulenev // Assign remaining available concurrency to other induction variables. 158*c30ab6c2SEugene Zhulenev for (size_t i = 1; i < op.getNumLoops(); ++i) { 159*c30ab6c2SEugene Zhulenev targetNumBlocks[i] = divup(targetNumBlocks[i - 1], numBlocks[i - 1]); 160*c30ab6c2SEugene Zhulenev blockSize[i] = divup(tripCounts[i], targetNumBlocks[i]); 161*c30ab6c2SEugene Zhulenev numBlocks[i] = divup(tripCounts[i], blockSize[i]); 162*c30ab6c2SEugene Zhulenev } 163*c30ab6c2SEugene Zhulenev 164*c30ab6c2SEugene Zhulenev // Create an async.group to wait on all async tokens from async execute ops. 165*c30ab6c2SEugene Zhulenev auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx)); 166*c30ab6c2SEugene Zhulenev 167*c30ab6c2SEugene Zhulenev // Build a scf.for loop nest from the parallel operation. 
168*c30ab6c2SEugene Zhulenev 169*c30ab6c2SEugene Zhulenev // Lower/upper bounds for nest block level computations. 170*c30ab6c2SEugene Zhulenev SmallVector<Value, 4> blockLowerBounds(op.getNumLoops()); 171*c30ab6c2SEugene Zhulenev SmallVector<Value, 4> blockUpperBounds(op.getNumLoops()); 172*c30ab6c2SEugene Zhulenev SmallVector<Value, 4> blockInductionVars(op.getNumLoops()); 173*c30ab6c2SEugene Zhulenev 174*c30ab6c2SEugene Zhulenev using LoopBodyBuilder = 175*c30ab6c2SEugene Zhulenev std::function<void(OpBuilder &, Location, Value, ValueRange)>; 176*c30ab6c2SEugene Zhulenev using LoopBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>; 177*c30ab6c2SEugene Zhulenev 178*c30ab6c2SEugene Zhulenev // Builds inner loop nest inside async.execute operation that does all the 179*c30ab6c2SEugene Zhulenev // work concurrently. 180*c30ab6c2SEugene Zhulenev LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder { 181*c30ab6c2SEugene Zhulenev return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) { 182*c30ab6c2SEugene Zhulenev blockInductionVars[loopIdx] = iv; 183*c30ab6c2SEugene Zhulenev 184*c30ab6c2SEugene Zhulenev // Continute building async loop nest. 185*c30ab6c2SEugene Zhulenev if (loopIdx < op.getNumLoops() - 1) { 186*c30ab6c2SEugene Zhulenev b.create<scf::ForOp>( 187*c30ab6c2SEugene Zhulenev loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1], 188*c30ab6c2SEugene Zhulenev op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1)); 189*c30ab6c2SEugene Zhulenev b.create<scf::YieldOp>(loc); 190*c30ab6c2SEugene Zhulenev return; 191*c30ab6c2SEugene Zhulenev } 192*c30ab6c2SEugene Zhulenev 193*c30ab6c2SEugene Zhulenev // Copy the body of the parallel op with new loop bounds. 
194*c30ab6c2SEugene Zhulenev BlockAndValueMapping mapping; 195*c30ab6c2SEugene Zhulenev mapping.map(op.getInductionVars(), blockInductionVars); 196*c30ab6c2SEugene Zhulenev 197*c30ab6c2SEugene Zhulenev for (auto &bodyOp : op.getLoopBody().getOps()) 198*c30ab6c2SEugene Zhulenev b.clone(bodyOp, mapping); 199*c30ab6c2SEugene Zhulenev }; 200*c30ab6c2SEugene Zhulenev }; 201*c30ab6c2SEugene Zhulenev 202*c30ab6c2SEugene Zhulenev // Builds a loop nest that does async execute op dispatching. 203*c30ab6c2SEugene Zhulenev LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder { 204*c30ab6c2SEugene Zhulenev return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) { 205*c30ab6c2SEugene Zhulenev auto lb = op.lowerBound()[loopIdx]; 206*c30ab6c2SEugene Zhulenev auto ub = op.upperBound()[loopIdx]; 207*c30ab6c2SEugene Zhulenev auto step = op.step()[loopIdx]; 208*c30ab6c2SEugene Zhulenev 209*c30ab6c2SEugene Zhulenev // Compute lower bound for the current block: 210*c30ab6c2SEugene Zhulenev // blockLowerBound = iv * blockSize * step + lowerBound 211*c30ab6c2SEugene Zhulenev auto s0 = b.create<MulIOp>(loc, iv, blockSize[loopIdx]); 212*c30ab6c2SEugene Zhulenev auto s1 = b.create<MulIOp>(loc, s0, step); 213*c30ab6c2SEugene Zhulenev auto s2 = b.create<AddIOp>(loc, s1, lb); 214*c30ab6c2SEugene Zhulenev blockLowerBounds[loopIdx] = s2; 215*c30ab6c2SEugene Zhulenev 216*c30ab6c2SEugene Zhulenev // Compute upper bound for the current block: 217*c30ab6c2SEugene Zhulenev // blockUpperBound = min(upperBound, 218*c30ab6c2SEugene Zhulenev // blockLowerBound + blockSize * step) 219*c30ab6c2SEugene Zhulenev auto e0 = b.create<MulIOp>(loc, blockSize[loopIdx], step); 220*c30ab6c2SEugene Zhulenev auto e1 = b.create<AddIOp>(loc, e0, s2); 221*c30ab6c2SEugene Zhulenev auto e2 = b.create<CmpIOp>(loc, CmpIPredicate::slt, e1, ub); 222*c30ab6c2SEugene Zhulenev auto e3 = b.create<SelectOp>(loc, e2, e1, ub); 223*c30ab6c2SEugene Zhulenev blockUpperBounds[loopIdx] = e3; 
224*c30ab6c2SEugene Zhulenev 225*c30ab6c2SEugene Zhulenev // Continue building async dispatch loop nest. 226*c30ab6c2SEugene Zhulenev if (loopIdx < op.getNumLoops() - 1) { 227*c30ab6c2SEugene Zhulenev b.create<scf::ForOp>(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(), 228*c30ab6c2SEugene Zhulenev asyncLoopBuilder(loopIdx + 1)); 229*c30ab6c2SEugene Zhulenev b.create<scf::YieldOp>(loc); 230*c30ab6c2SEugene Zhulenev return; 231*c30ab6c2SEugene Zhulenev } 232*c30ab6c2SEugene Zhulenev 233*c30ab6c2SEugene Zhulenev // Build the inner loop nest that will do the actual work inside the 234*c30ab6c2SEugene Zhulenev // `async.execute` body region. 235*c30ab6c2SEugene Zhulenev auto executeBodyBuilder = [&](OpBuilder &executeBuilder, 236*c30ab6c2SEugene Zhulenev Location executeLoc, 237*c30ab6c2SEugene Zhulenev ValueRange executeArgs) { 238*c30ab6c2SEugene Zhulenev executeBuilder.create<scf::ForOp>(executeLoc, blockLowerBounds[0], 239*c30ab6c2SEugene Zhulenev blockUpperBounds[0], op.step()[0], 240*c30ab6c2SEugene Zhulenev ValueRange(), workLoopBuilder(0)); 241*c30ab6c2SEugene Zhulenev executeBuilder.create<async::YieldOp>(executeLoc, ValueRange()); 242*c30ab6c2SEugene Zhulenev }; 243*c30ab6c2SEugene Zhulenev 244*c30ab6c2SEugene Zhulenev auto execute = b.create<ExecuteOp>( 245*c30ab6c2SEugene Zhulenev loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(), 246*c30ab6c2SEugene Zhulenev /*operands=*/ValueRange(), executeBodyBuilder); 247*c30ab6c2SEugene Zhulenev auto rankType = IndexType::get(ctx); 248*c30ab6c2SEugene Zhulenev b.create<AddToGroupOp>(loc, rankType, execute.token(), group.result()); 249*c30ab6c2SEugene Zhulenev b.create<scf::YieldOp>(loc); 250*c30ab6c2SEugene Zhulenev }; 251*c30ab6c2SEugene Zhulenev }; 252*c30ab6c2SEugene Zhulenev 253*c30ab6c2SEugene Zhulenev // Start building a loop nest from the first induction variable. 
254*c30ab6c2SEugene Zhulenev rewriter.create<scf::ForOp>(loc, c0, numBlocks[0], c1, ValueRange(), 255*c30ab6c2SEugene Zhulenev asyncLoopBuilder(0)); 256*c30ab6c2SEugene Zhulenev 257*c30ab6c2SEugene Zhulenev // Wait for the completion of all subtasks. 258*c30ab6c2SEugene Zhulenev rewriter.create<AwaitAllOp>(loc, group.result()); 259*c30ab6c2SEugene Zhulenev 260*c30ab6c2SEugene Zhulenev // Erase the original parallel operation. 261*c30ab6c2SEugene Zhulenev rewriter.eraseOp(op); 262*c30ab6c2SEugene Zhulenev 263*c30ab6c2SEugene Zhulenev return success(); 264*c30ab6c2SEugene Zhulenev } 265*c30ab6c2SEugene Zhulenev 266*c30ab6c2SEugene Zhulenev void AsyncParallelForPass::runOnFunction() { 267*c30ab6c2SEugene Zhulenev MLIRContext *ctx = &getContext(); 268*c30ab6c2SEugene Zhulenev 269*c30ab6c2SEugene Zhulenev OwningRewritePatternList patterns; 270*c30ab6c2SEugene Zhulenev patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute); 271*c30ab6c2SEugene Zhulenev 272*c30ab6c2SEugene Zhulenev if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) 273*c30ab6c2SEugene Zhulenev signalPassFailure(); 274*c30ab6c2SEugene Zhulenev } 275*c30ab6c2SEugene Zhulenev 276*c30ab6c2SEugene Zhulenev std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() { 277*c30ab6c2SEugene Zhulenev return std::make_unique<AsyncParallelForPass>(); 278*c30ab6c2SEugene Zhulenev } 279