//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
//   }
//
// And a dispatch function depending on the `asyncDispatch` option.
//
// When async dispatch is on: (pseudocode)
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call the parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for the [0, block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//      call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
struct AsyncParallelForPass
    : public AsyncParallelForBase<AsyncParallelForPass> {
  AsyncParallelForPass() = default;
  void runOnOperation() override;
};

struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
                          int32_t numWorkerThreads, int32_t targetBlockSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads), targetBlockSize(targetBlockSize) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  // The maximum number of tasks per worker thread when sharding a parallel op.
  static constexpr int32_t kMaxOversharding = 4;

  bool asyncDispatch;
  int32_t numWorkerThreads;
  int32_t targetBlockSize;
};

struct ParallelComputeFunctionType {
  FunctionType type;
  llvm::SmallVector<Value> captures;
};

struct ParallelComputeFunction {
  FuncOp func;
  llvm::SmallVector<Value> captures;
};

} // namespace

// Converts a one-dimensional iteration index in the [0, tripCount) interval
// into a multi-dimensional iteration coordinate.
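// For example (illustrative), with tripCounts = [3, 4] the flattened index 7
// maps to the coordinate [1, 3], because 7 = 1 * 4 + 3.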
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      const SmallVector<Value> &tripCounts) {
  SmallVector<Value> coords(tripCounts.size());
  assert(!tripCounts.empty() && "tripCounts must not be empty");

  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = b.create<SignedRemIOp>(index, tripCounts[i]);
    index = b.create<SignedDivIOp>(index, tripCounts[i]);
  }

  return coords;
}

// Returns a function type and implicit captures for a parallel compute
// function. We need the list of implicit captures to set up the block and
// value mapping when we clone the body of the parallel operation.
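//
// For a two-dimensional scf.parallel with no implicit captures the resulting
// function type has (illustrative) 2 + 2 + 3 * 2 = 10 index arguments and no
// results: the block index and size, two loop trip counts, and the lower
// bound, upper bound and step for each of the two loops.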
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.region(), op.region(), captures);

  llvm::SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step.
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicit captures.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to a vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}

// Creates a parallel compute function from the parallel operation.
static ParallelComputeFunction
createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ModuleOp module = op->getParentOfType<ModuleOp>();
  b.setInsertionPointToStart(&module->getRegion(0).front());

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
  func.setPrivate();

  // Insert the function into the module symbol table and assign it a unique
  // name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func);

  // Create the function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
  b.setInsertionPointToEnd(block);

  unsigned offset = 0; // argument offset for argument decoding

  // Load multiple arguments into a values vector.
  auto getArguments = [&](unsigned num_arguments) -> SmallVector<Value> {
    SmallVector<Value> values(num_arguments);
    for (unsigned i = 0; i < num_arguments; ++i)
      values[i] = block->getArgument(offset++);
    return values;
  };

  // Block iteration position defined by the block index and size.
  Value blockIndex = block->getArgument(offset++);
  Value blockSize = block->getArgument(offset++);

  // Constants used below.
  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  SmallVector<Value> tripCounts = getArguments(op.getNumLoops());

  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);

  // Parallel operation lower bound, upper bound and step.
  SmallVector<Value> lowerBound = getArguments(op.getNumLoops());
  SmallVector<Value> upperBound = getArguments(op.getNumLoops());
  SmallVector<Value> step = getArguments(op.getNumLoops());

  // Remaining arguments are implicit captures of the parallel operation.
  SmallVector<Value> captures = getArguments(block->getNumArguments() - offset);

  // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = b.create<MulIOp>(blockIndex, blockSize);

  // The last one-dimensional index in the block defined by the `blockIndex`:
  //   blockLastIndex = min((blockIndex + 1) * blockSize, tripCount) - 1
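  //
  // For example (illustrative): with tripCount = 100, blockSize = 8 and
  // blockIndex = 12, the block covers the one-dimensional indices [96, 99].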
  Value blockEnd0 = b.create<AddIOp>(blockIndex, c1);
  Value blockEnd1 = b.create<MulIOp>(blockEnd0, blockSize);
  Value blockEnd2 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd1, tripCount);
  Value blockEnd3 = b.create<SelectOp>(blockEnd2, tripCount, blockEnd1);
  Value blockLastIndex = b.create<SubIOp>(blockEnd3, c1);

  // Convert one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute the upper bounds of the compute loops from the block last
  // coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  //
  // Block first and last coordinates can be the same along the outer compute
  // dimension when the inner compute dimension contains multiple blocks.
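  //
  // For example (illustrative): with tripCounts = [10, 100], a block covering
  // the one-dimensional range [5, 55] has blockFirstCoord = [0, 5] and
  // blockLastCoord = [0, 55]; both share the same outer coordinate.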
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = b.create<AddIOp>(blockLastCoord[i], c1);

  // Construct a loop nest out of scf.for operations that will iterate over
  // all coordinates in the [blockFirstCoord, blockLastCoord] range.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Parallel region induction variables computed from the multi-dimensional
  // iteration coordinate using parallel operation bounds and step:
  //
  //   computeBlockInductionVars[loopIdx] =
  //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());

  // We need to know if we are in the first or last iteration of the
  // multi-dimensional loop for each loop in the nest, so we can decide which
  // loop bounds to use for the nested loops: bounds defined by the compute
  // block interval, or bounds defined by the parallel operation.
  //
  // Example: 2d parallel operation
  //                   i   j
  //   loop sizes:   [50, 50]
  //   first coord:  [25, 25]
  //   last coord:   [30, 30]
  //
  // If `i` is equal to 25 then iteration over `j` should start at 25, when `i`
  // is between 25 and 30 it should start at 0. The upper bound for `j` should
  // be 50, except when `i` is equal to 30, then it should also be 30.
  //
  // The value at the ith position specifies if all loops in the [0, i] range
  // of the loop nest are in the first/last iteration.
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Builds the inner loop nest of the parallel compute function that does all
  // the work for a single compute block.
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder nb(loc, nestedBuilder);

      // Compute the induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] = nb.create<AddIOp>(
          lowerBound[loopIdx], nb.create<MulIOp>(iv, step[loopIdx]));

      // Check if we are inside the first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] =
          nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] =
          nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Check if the previous loop is in its first or last iteration.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = nb.create<AndOp>(
            isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = nb.create<AndOp>(
            isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building the loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        // Select nested loop lower/upper bounds depending on our position in
        // the multi-dimensional iteration space.
        auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
                                      blockFirstCoord[loopIdx + 1], c0);

        auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
                                      blockEndCoord[loopIdx + 1],
                                      tripCounts[loopIdx + 1]);

        nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
                              workLoopBuilder(loopIdx + 1));
        nb.create<scf::YieldOp>(loc);
        return;
      }

      // Copy the body of the parallel op into the inner-most loop.
      BlockAndValueMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);

      for (auto &bodyOp : op.getLoopBody().getOps())
        nb.clone(bodyOp, mapping);
    };
  };

  b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                       workLoopBuilder(0));
  b.create<ReturnOp>(ValueRange());

  return {func, std::move(computeFuncType.captures)};
}

// Creates a recursive async dispatch function for the given parallel compute
// function. The dispatch function keeps splitting the block range into halves
// until it reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting the block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call the parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
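// For example (illustrative): dispatching the block range [0, 8) spawns async
// dispatches for the ranges [4, 8), [2, 4) and [1, 2), and the calling thread
// then executes block 0 inline.
//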
static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                                          PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
  b.setInsertionPointToStart(&module->getRegion(0).front());

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.type().cast<FunctionType>().getInputs();

  // Compared to the parallel compute function, the async dispatch function
  // takes an additional !async.group argument. Also, instead of a single
  // `blockIndex` it takes `blockStart` and `blockEnd` arguments to define the
  // range of dispatched blocks.
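  //
  // The resulting entry block arguments are (illustrative):
  //   %group : !async.group, %blockStart : index, %blockEnd : index,
  //   %blockSize : index, <loop trip counts>, <bounds and steps>, <captures>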
  SmallVector<Type> inputTypes;
  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
  inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  FuncOp func = FuncOp::create(loc, "async_dispatch_fn", type);
  func.setPrivate();

  // Insert the function into the module symbol table and assign it a unique
  // name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func);

  // Create the function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
  b.setInsertionPointToEnd(block);

  Type indexTy = b.getIndexType();
  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
  Value c2 = b.create<ConstantOp>(b.getIndexAttr(2));

  // Get the async group that will track async dispatch completion.
  Value group = block->getArgument(0);

  // Get the block iteration range: [blockStart, blockEnd)
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};

  // Create a recursive dispatch loop.
  scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
  Block *before = b.createBlock(&whileOp.before(), {}, types);
  Block *after = b.createBlock(&whileOp.after(), {}, types);

  // Set up the dispatch loop condition block: decide if we need to go into the
  // `after` block and launch one more async dispatch.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = b.create<SubIOp>(end, start);
    Value dispatch = b.create<CmpIOp>(CmpIPredicate::sgt, distance, c1);
    b.create<scf::ConditionOp>(dispatch, before->getArguments());
  }

  // Set up the async dispatch loop body: recursively call the dispatch
  // function for the second half of the original range and go to the next
  // iteration.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = b.create<SubIOp>(end, start);
    Value halfDistance = b.create<SignedDivIOp>(distance, c2);
    Value midIndex = b.create<AddIOp>(start, halfDistance);

    // Recursively call the async dispatch function inside the async.execute
    // region to process the second half of the block range.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original `blockStart` and `blockEnd` with the new range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex;
      operands[2] = end;

      executeBuilder.create<CallOp>(executeLoc, func.sym_name(),
                                    func.getCallableResults(), operands);
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create an async.execute operation to dispatch half of the block range.
    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                       executeBodyBuilder);
    b.create<AddToGroupOp>(indexTy, execute.token(), group);
    b.create<scf::YieldOp>(ValueRange({start, midIndex}));
  }

  // After dispatching async operations to process the tail of the block range,
  // call the parallel compute function for the first block of the range.
  b.setInsertionPointAfter(whileOp);

  // Drop the async dispatch specific arguments: async group, block start and
  // end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  b.create<CallOp>(computeFunc.func.sym_name(),
                   computeFunc.func.getCallableResults(), computeFuncOperands);
  b.create<ReturnOp>(ValueRange());

  return func;
}

// Launch async dispatch of the parallel compute function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = b.create<SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  // Pack the async dispatch function operands to launch the work splitting.
  SmallVector<Value> asyncDispatchOperands = {group, c0, blockCount, blockSize};
  asyncDispatchOperands.append(tripCounts);
  asyncDispatchOperands.append(op.lowerBound().begin(), op.lowerBound().end());
  asyncDispatchOperands.append(op.upperBound().begin(), op.upperBound().end());
  asyncDispatchOperands.append(op.step().begin(), op.step().end());
  asyncDispatchOperands.append(parallelComputeFunction.captures);

  // Launch the async dispatch function for the [0, blockCount) range.
  b.create<CallOp>(asyncDispatchFunction.sym_name(),
                   asyncDispatchFunction.getCallableResults(),
                   asyncDispatchOperands);

  // Wait for the completion of all parallel compute operations.
  b.create<AwaitAllOp>(group);
}

// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  FuncOp compute = parallelComputeFunction.func;

  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = b.create<SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  // Call the parallel compute function for all blocks.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns parallel compute function operands to process the given block.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.lowerBound().begin(), op.lowerBound().end());
    computeFuncOperands.append(op.upperBound().begin(), op.upperBound().end());
    computeFuncOperands.append(op.step().begin(), op.step().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // The induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder nb(loc, loopBuilder);

    // Call the parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      executeBuilder.create<CallOp>(executeLoc, compute.sym_name(),
                                    compute.getCallableResults(),
                                    computeFuncOperands(iv));
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create an async.execute operation to launch the parallel compute
    // function.
    auto execute = nb.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                        executeBodyBuilder);
    nb.create<AddToGroupOp>(rewriter.getIndexType(), execute.token(), group);
    nb.create<scf::YieldOp>();
  };

  // Iterate over all compute blocks but the first and launch parallel compute
  // operations asynchronously.
  b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call the parallel compute function for the first block in the caller
  // thread.
  b.create<CallOp>(compute.sym_name(), compute.getCallableResults(),
                   computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  b.create<AwaitAllOp>(group);
}

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support rewriting parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Compute the trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step);
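  //
  // For example (illustrative): lowerBound = 0, upperBound = 10 and step = 3
  // give tripCount = ceil_div(10, 3) = 4 iterations (i = 0, 3, 6, 9).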
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.lowerBound()[i];
    auto ub = op.upperBound()[i];
    auto step = op.step()[i];
    auto range = b.create<SubIOp>(ub, lb);
    tripCounts[i] = b.create<SignedCeilDivIOp>(range, step);
  }

  // Compute a product of trip counts to get the 1-dimensional iteration space
  // for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);

  auto indexTy = b.getIndexType();

  // Do not overload worker threads with too many compute blocks.
  Value maxComputeBlocks = b.create<ConstantOp>(
      indexTy, b.getIndexAttr(numWorkerThreads * kMaxOversharding));

  // Target block size from the pass parameters.
  Value targetComputeBlockSize =
      b.create<ConstantOp>(indexTy, b.getIndexAttr(targetBlockSize));

  // Compute the parallel block size from the parallel problem size:
  //   blockSize = min(tripCount,
  //                   max(divup(tripCount, maxComputeBlocks),
  //                       targetComputeBlockSize))
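  //
  // For example (illustrative): with tripCount = 100000, numWorkerThreads = 8
  // and targetBlockSize = 1000, maxComputeBlocks = 8 * 4 = 32, so
  // blockSize = divup(100000, 32) = 3125 and blockCount = 32.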
  Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
  Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
  Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
  Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
  Value blockSize = b.create<SelectOp>(bs3, tripCount, bs2);
  Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);

  // Create a parallel compute function that takes a block id and computes the
  // parallel operation body for a subset of the iteration space.
  ParallelComputeFunction parallelComputeFunction =
      createParallelComputeFunction(op, rewriter);

  // Dispatch the parallel compute function using recursive async work
  // splitting, or by submitting compute tasks sequentially from the caller
  // thread.
  if (asyncDispatch) {
    doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
                    blockCount, tripCounts);
  } else {
    doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
                         blockCount, tripCounts);
  }

  // The parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        targetBlockSize);

  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
  return std::make_unique<AsyncParallelForPass>();
}