1c30ab6c2SEugene Zhulenev //===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===// 2c30ab6c2SEugene Zhulenev // 3c30ab6c2SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c30ab6c2SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5c30ab6c2SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c30ab6c2SEugene Zhulenev // 7c30ab6c2SEugene Zhulenev //===----------------------------------------------------------------------===// 8c30ab6c2SEugene Zhulenev // 9*86ad0af8SEugene Zhulenev // This file implements scf.parallel to scf.for + async.execute conversion pass. 10c30ab6c2SEugene Zhulenev // 11c30ab6c2SEugene Zhulenev //===----------------------------------------------------------------------===// 12c30ab6c2SEugene Zhulenev 13c30ab6c2SEugene Zhulenev #include "PassDetail.h" 14c30ab6c2SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 15c30ab6c2SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h" 16c30ab6c2SEugene Zhulenev #include "mlir/Dialect/SCF/SCF.h" 17c30ab6c2SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h" 18c30ab6c2SEugene Zhulenev #include "mlir/IR/BlockAndValueMapping.h" 19*86ad0af8SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h" 20c30ab6c2SEugene Zhulenev #include "mlir/IR/PatternMatch.h" 21c30ab6c2SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22*86ad0af8SEugene Zhulenev #include "mlir/Transforms/RegionUtils.h" 23c30ab6c2SEugene Zhulenev 24c30ab6c2SEugene Zhulenev using namespace mlir; 25c30ab6c2SEugene Zhulenev using namespace mlir::async; 26c30ab6c2SEugene Zhulenev 27c30ab6c2SEugene Zhulenev #define DEBUG_TYPE "async-parallel-for" 28c30ab6c2SEugene Zhulenev 29c30ab6c2SEugene Zhulenev namespace { 30c30ab6c2SEugene Zhulenev 31c30ab6c2SEugene Zhulenev // Rewrite scf.parallel operation into multiple concurrent async.execute 32c30ab6c2SEugene Zhulenev // operations over non 
overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
//   }
//
// And a dispatch function depending on the `asyncDispatch` option.
//
// When async dispatch is on: (pseudocode)
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//     // Keep splitting block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for [0, block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//     call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
// Pass wrapper: runs the scf.parallel -> async dispatch conversion over the
// operation it is scheduled on (see the conversion outline above).
struct AsyncParallelForPass
    : public AsyncParallelForBase<AsyncParallelForPass> {
  AsyncParallelForPass() = default;
  void runOnOperation() override;
};

// Rewrites a single scf.parallel operation into an outlined parallel compute
// function plus a dispatch of that function over the sharded iteration space.
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
                          int32_t numWorkerThreads, int32_t targetBlockSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads), targetBlockSize(targetBlockSize) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  // The maximum number of tasks per worker thread when sharding parallel op.
  static constexpr int32_t kMaxOversharding = 4;

  // Whether to dispatch blocks recursively via async.execute (true) or from a
  // plain scf.for loop in the caller thread (false).
  bool asyncDispatch;
  int32_t numWorkerThreads;
  int32_t targetBlockSize;
};

// Function type of the outlined parallel compute function, together with the
// values implicitly captured by the parallel operation body. The captures must
// be appended to the call operands of the compute function.
struct ParallelComputeFunctionType {
  FunctionType type;
  llvm::SmallVector<Value> captures;
};

// The outlined parallel compute function and the captured values that callers
// must forward as trailing operands.
struct ParallelComputeFunction {
  FuncOp func;
  llvm::SmallVector<Value> captures;
};

} // namespace

// Converts one-dimensional iteration index in the [0, tripCount) interval
// into multidimensional iteration coordinate.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      const SmallVector<Value> &tripCounts) {
  SmallVector<Value> coords(tripCounts.size());
  assert(!tripCounts.empty() && "tripCounts must be not empty");

  // Peel dimensions off the linear index starting from the innermost (fastest
  // varying) dimension.
  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = b.create<SignedRemIOp>(index, tripCounts[i]);
    index = b.create<SignedDivIOp>(index, tripCounts[i]);
  }

  return coords;
}

// Returns a function type and implicit captures for a parallel compute
// function. We'll need a list of implicit captures to set up the block and
// value mapping when we clone the body of the parallel operation.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.region(), op.region(), captures);

  llvm::SmallVector<Type> inputs;
  // 2 leading index args + (tripCount, lb, ub, step) per loop + captures.
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step.
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicit captures.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}

// Create a parallel compute function from the parallel operation.
static ParallelComputeFunction
createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Outline the compute function at module scope.
  ModuleOp module = op->getParentOfType<ModuleOp>();
  b.setInsertionPointToStart(&module->getRegion(0).front());

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
  func.setPrivate();

  // Insert function into the module symbol table and assign it unique name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  // NOTE(review): assumes the rewriter always has a listener attached —
  // confirm, otherwise this dereferences null.
  rewriter.getListener()->notifyOperationInserted(func);

  // Create function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
  b.setInsertionPointToEnd(block);

  unsigned offset = 0; // argument offset for arguments decoding

  // Load multiple arguments into values vector.
  auto getArguments = [&](unsigned num_arguments) -> SmallVector<Value> {
    SmallVector<Value> values(num_arguments);
    for (unsigned i = 0; i < num_arguments; ++i)
      values[i] = block->getArgument(offset++);
    return values;
  };

  // Block iteration position defined by the block index and size.
  Value blockIndex = block->getArgument(offset++);
  Value blockSize = block->getArgument(offset++);

  // Constants used below.
  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));

  // Multi-dimensional parallel iteration space defined by the loop trip counts.
  SmallVector<Value> tripCounts = getArguments(op.getNumLoops());

  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);

  // Parallel operation lower bound, upper bound and step.
  SmallVector<Value> lowerBound = getArguments(op.getNumLoops());
  SmallVector<Value> upperBound = getArguments(op.getNumLoops());
  SmallVector<Value> step = getArguments(op.getNumLoops());

  // Remaining arguments are implicit captures of the parallel operation.
  SmallVector<Value> captures = getArguments(block->getNumArguments() - offset);

  // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = b.create<MulIOp>(blockIndex, blockSize);

  // The last one-dimensional index in the block defined by the `blockIndex`:
  //   blockLastIndex = min((blockIndex + 1) * blockSize, tripCount) - 1
  // (the sge compare + select below clamps the block end to tripCount).
  Value blockEnd0 = b.create<AddIOp>(blockIndex, c1);
  Value blockEnd1 = b.create<MulIOp>(blockEnd0, blockSize);
  Value blockEnd2 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd1, tripCount);
  Value blockEnd3 = b.create<SelectOp>(blockEnd2, tripCount, blockEnd1);
  Value blockLastIndex = b.create<SubIOp>(blockEnd3, c1);

  // Convert one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute the compute-loop upper bounds from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  //
  // Block first and last coordinates can be the same along the outer compute
  // dimension when the inner compute dimension contains multiple blocks.
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = b.create<AddIOp>(blockLastCoord[i], c1);

  // Construct a loop nest out of scf.for operations that will iterate over
  // all coordinates in [blockFirstCoord, blockLastCoord] range.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Parallel region induction variables computed from the multi-dimensional
  // iteration coordinate using parallel operation bounds and step:
  //
  //   computeBlockInductionVars[loopIdx] =
  //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());

  // We need to know if we are in the first or last iteration of the
  // multi-dimensional loop for each loop in the nest, so we can decide what
  // loop bounds should we use for the nested loops: bounds defined by compute
  // block interval, or bounds defined by the parallel operation.
  //
  // Example: 2d parallel operation
  //                  i   j
  //   loop sizes:  [50, 50]
  //   first coord: [25, 25]
  //   last coord:  [30, 30]
  //
  // If `i` is equal to 25 then iteration over `j` should start at 25, when `i`
  // is between 25 and 30 it should start at 0. The upper bound for `j` should
  // be 50, except when `i` is equal to 30, then it should also be 30.
  //
  // Value at ith position specifies if all loops in [0, i) range of the loop
  // nest are in the first/last iteration.
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Builds inner loop nest inside async.execute operation that does all the
  // work concurrently.
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder nb(loc, nestedBuilder);

      // Compute induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] = nb.create<AddIOp>(
          lowerBound[loopIdx], nb.create<MulIOp>(iv, step[loopIdx]));

      // Check if we are inside first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] =
          nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] =
          nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Check if the previous loop is in its first or last iteration.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = nb.create<AndOp>(
            isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = nb.create<AndOp>(
            isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        // Select nested loop lower/upper bounds depending on our position in
        // the multi-dimensional iteration space.
        auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
                                      blockFirstCoord[loopIdx + 1], c0);

        auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
                                      blockEndCoord[loopIdx + 1],
                                      tripCounts[loopIdx + 1]);

        nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
                              workLoopBuilder(loopIdx + 1));
        nb.create<scf::YieldOp>(loc);
        return;
      }

      // Copy the body of the parallel op into the inner-most loop.
      BlockAndValueMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);

      for (auto &bodyOp : op.getLoopBody().getOps())
        nb.clone(bodyOp, mapping);
    };
  };

  // Outermost loop of the nest; inner loops are created recursively by
  // `workLoopBuilder`.
  b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                       workLoopBuilder(0));
  b.create<ReturnOp>(ValueRange());

  return {func, std::move(computeFuncType.captures)};
}

// Creates recursive async dispatch function for the given parallel compute
// function. Dispatch function keeps splitting block range into halves until it
// reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                                          PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  // Outline the dispatch function at module scope.
  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
  b.setInsertionPointToStart(&module->getRegion(0).front());

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.type().cast<FunctionType>().getInputs();

  // Compared to the parallel compute function async dispatch function takes
  // additional !async.group argument. Also instead of a single `blockIndex` it
  // takes `blockStart` and `blockEnd` arguments to define the range of
  // dispatched blocks. Resulting argument layout:
  //   [group, blockStart, blockEnd, blockSize, <compute fn args>...]
  // (the compute function's `blockIndex` slot is reused as `blockEnd`).
  SmallVector<Type> inputTypes;
  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
  inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  FuncOp func = FuncOp::create(loc, "async_dispatch_fn", type);
  func.setPrivate();

  // Insert function into the module symbol table and assign it unique name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  // NOTE(review): assumes the rewriter always has a listener attached —
  // confirm, otherwise this dereferences null.
  rewriter.getListener()->notifyOperationInserted(func);

  // Create function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
  b.setInsertionPointToEnd(block);

  Type indexTy = b.getIndexType();
  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
  Value c2 = b.create<ConstantOp>(b.getIndexAttr(2));

  // Get the async group that will track async dispatch completion.
  Value group = block->getArgument(0);

  // Get the block iteration range: [blockStart, blockEnd)
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};

  // Create a recursive dispatch loop.
  scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
  Block *before = b.createBlock(&whileOp.before(), {}, types);
  Block *after = b.createBlock(&whileOp.after(), {}, types);

  // Setup dispatch loop condition block: decide if we need to go into the
  // `after` block and launch one more async dispatch.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = b.create<SubIOp>(end, start);
    // Keep splitting while the range contains more than one block.
    Value dispatch = b.create<CmpIOp>(CmpIPredicate::sgt, distance, c1);
    b.create<scf::ConditionOp>(dispatch, before->getArguments());
  }

  // Setup the async dispatch loop body: recursively call dispatch function
  // for the second half of the original range and go to the next iteration.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = b.create<SubIOp>(end, start);
    Value halfDistance = b.create<SignedDivIOp>(distance, c2);
    Value midIndex = b.create<AddIOp>(after->getArgument(0), halfDistance);

    // Call parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original `blockStart` and `blockEnd` with new range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex; // blockStart of the dispatched half
      operands[2] = end;      // blockEnd of the dispatched half

      executeBuilder.create<CallOp>(executeLoc, func.sym_name(),
                                    func.getCallableResults(), operands);
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create async.execute operation to dispatch half of the block range.
    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                       executeBodyBuilder);
    // Track the async token in the group so the caller can await completion.
    b.create<AddToGroupOp>(indexTy, execute.token(), group);
    // Continue the while loop with the first half: [start, midIndex).
    b.create<scf::YieldOp>(ValueRange({after->getArgument(0), midIndex}));
  }

  // After dispatching async operations to process the tail of the block range
  // call the parallel compute function for the first block of the range.
  b.setInsertionPointAfter(whileOp);

  // Drop async dispatch specific arguments: async group, block start and end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  b.create<CallOp>(computeFunc.func.sym_name(),
                   computeFunc.func.getCallableResults(), computeFuncOperands);
  b.create<ReturnOp>(ValueRange());

  return func;
}

// Launch async dispatch of the parallel compute function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block is
  // executed synchronously in the caller thread, so the group only needs to
  // track `blockCount - 1` tokens.
  Value groupSize = b.create<SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  // Pack the async dispatch function operands to launch the work splitting:
  // the group, the initial [0, blockCount) block range, the block size, and
  // then the parallel op properties (trip counts, bounds, steps, captures) in
  // the same order the dispatch function expects them.
  SmallVector<Value> asyncDispatchOperands = {group, c0, blockCount, blockSize};
  asyncDispatchOperands.append(tripCounts);
  asyncDispatchOperands.append(op.lowerBound().begin(), op.lowerBound().end());
  asyncDispatchOperands.append(op.upperBound().begin(), op.upperBound().end());
  asyncDispatchOperands.append(op.step().begin(), op.step().end());
  asyncDispatchOperands.append(parallelComputeFunction.captures);

  // Launch async dispatch function for [0, blockCount) range.
  b.create<CallOp>(asyncDispatchFunction.sym_name(),
                   asyncDispatchFunction.getCallableResults(),
                   asyncDispatchOperands);

  // Wait for the completion of all parallel compute operations.
  b.create<AwaitAllOp>(group);
}
// TODO: rename to `doSequentialDispatch` ("Sequantial" is a typo); requires
// updating the single call site in AsyncParallelForRewrite::matchAndRewrite.
static void
doSequantialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  FuncOp compute = parallelComputeFunction.func;

  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block is
  // executed synchronously in the caller thread, so the group only needs to
  // track `blockCount - 1` tokens.
  Value groupSize = b.create<SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  // Call parallel compute function for all blocks.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns parallel compute function operands to process the given block:
  // the block index and size followed by the parallel op properties (trip
  // counts, bounds, steps, captures).
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.lowerBound().begin(), op.lowerBound().end());
    computeFuncOperands.append(op.upperBound().begin(), op.upperBound().end());
    computeFuncOperands.append(op.step().begin(), op.step().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // Induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder nb(loc, loopBuilder);

    // Call parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      executeBuilder.create<CallOp>(executeLoc, compute.sym_name(),
                                    compute.getCallableResults(),
                                    computeFuncOperands(iv));
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create async.execute operation to launch the parallel compute function
    // asynchronously, and add its completion token to the group.
    auto execute = nb.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                        executeBodyBuilder);
    nb.create<AddToGroupOp>(rewriter.getIndexType(), execute.token(), group);
    nb.create<scf::YieldOp>();
  };

  // Iterate over all compute blocks and launch parallel compute operations.
  // Loop starts at block 1: block 0 is handled in the caller thread below.
  b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call parallel compute function for the first block in the caller thread.
  b.create<CallOp>(compute.sym_name(), compute.getCallableResults(),
                   computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  b.create<AwaitAllOp>(group);
}

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support rewrite for parallel op with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Compute trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step);
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.lowerBound()[i];
    auto ub = op.upperBound()[i];
    auto step = op.step()[i];
    auto range = b.create<SubIOp>(ub, lb);
    tripCounts[i] = b.create<SignedCeilDivIOp>(range, step);
  }

  // Compute a product of trip counts to get the 1-dimensional iteration space
  // for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);

  auto indexTy = b.getIndexType();

  // Do not overload worker threads with too many compute blocks.
  Value maxComputeBlocks = b.create<ConstantOp>(
      indexTy, b.getIndexAttr(numWorkerThreads * kMaxOversharding));

  // Target block size from the pass parameters.
  Value targetComputeBlockSize =
      b.create<ConstantOp>(indexTy, b.getIndexAttr(targetBlockSize));

  // Compute parallel block size from the parallel problem size:
  //   blockSize = min(tripCount,
  //                   max(divup(tripCount, maxComputeBlocks),
  //                       targetComputeBlockSize))
  // bs0..bs3 build the max/min chain with compare+select, since there are no
  // dedicated min/max ops available here.
  Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
  Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
  Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
  Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
  Value blockSize = b.create<SelectOp>(bs3, tripCount, bs2);
  Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);

  // Create a parallel compute function that takes a block id and computes the
  // parallel operation body for a subset of iteration space.
  ParallelComputeFunction parallelComputeFunction =
      createParallelComputeFunction(op, rewriter);

  // Dispatch parallel compute function using async recursive work splitting,
  // or by submitting compute task sequentially from a caller thread.
  if (asyncDispatch) {
    doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
                    blockCount, tripCounts);
  } else {
    doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
                         blockCount, tripCounts);
  }

  // Parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  // Apply the scf.parallel -> async dispatch rewrite greedily over the whole
  // operation, forwarding the pass options to the pattern.
  RewritePatternSet patterns(ctx);
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        targetBlockSize);

  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
  return std::make_unique<AsyncParallelForPass>();
}