1 //===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements scf.parallel to scf.for + async.execute conversion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Async/IR/Async.h"
16 #include "mlir/Dialect/Async/Passes.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "mlir/Transforms/RegionUtils.h"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
29 #define DEBUG_TYPE "async-parallel-for"
30 
31 namespace {
32 
33 // Rewrite scf.parallel operation into multiple concurrent async.execute
34 // operations over non overlapping subranges of the original loop.
35 //
36 // Example:
37 //
38 //   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
39 //     "do_some_compute"(%i, %j): () -> ()
40 //   }
41 //
42 // Converted to:
43 //
44 //   // Parallel compute function that executes the parallel body region for
45 //   // a subset of the parallel iteration space defined by the one-dimensional
46 //   // compute block index.
//   func @parallel_compute_fn(%block_index : index, %block_size : index,
//                             <parallel operation properties>, ...) {
49 //     // Compute multi-dimensional loop bounds for %block_index.
50 //     %block_lbi, %block_lbj = ...
51 //     %block_ubi, %block_ubj = ...
52 //
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
59 //   }
60 //
61 // And a dispatch function depending on the `asyncDispatch` option.
62 //
63 // When async dispatch is on: (pseudocode)
64 //
65 //   %block_size = ... compute parallel compute block size
66 //   %block_count = ... compute the number of compute blocks
67 //
68 //   func @async_dispatch(%block_start : index, %block_end : index, ...) {
69 //     // Keep splitting block range until we reached a range of size 1.
70 //     while (%block_end - %block_start > 1) {
71 //       %mid_index = block_start + (block_end - block_start) / 2;
72 //       async.execute { call @async_dispatch(%mid_index, %block_end); }
73 //       %block_end = %mid_index
74 //     }
75 //
76 //     // Call parallel compute function for a single block.
77 //     call @parallel_compute_fn(%block_start, %block_size, ...);
78 //   }
79 //
80 //   // Launch async dispatch for [0, block_count) range.
81 //   call @async_dispatch(%c0, %block_count);
82 //
83 // When async dispatch is off:
84 //
85 //   %block_size = ... compute parallel compute block size
86 //   %block_count = ... compute the number of compute blocks
87 //
88 //   scf.for %block_index = %c0 to %block_count {
89 //      call @parallel_compute_fn(%block_index, %block_size, ...)
90 //   }
91 //
92 struct AsyncParallelForPass
93     : public AsyncParallelForBase<AsyncParallelForPass> {
94   AsyncParallelForPass() = default;
95 
96   AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
97                        int32_t minTaskSize) {
98     this->asyncDispatch = asyncDispatch;
99     this->numWorkerThreads = numWorkerThreads;
100     this->minTaskSize = minTaskSize;
101   }
102 
103   void runOnOperation() override;
104 };
105 
106 struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
107 public:
108   AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
109                           int32_t numWorkerThreads, int32_t minTaskSize)
110       : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
111         numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {}
112 
113   LogicalResult matchAndRewrite(scf::ParallelOp op,
114                                 PatternRewriter &rewriter) const override;
115 
116 private:
117   bool asyncDispatch;
118   int32_t numWorkerThreads;
119   int32_t minTaskSize;
120 };
121 
122 struct ParallelComputeFunctionType {
123   FunctionType type;
124   SmallVector<Value> captures;
125 };
126 
127 // Helper struct to parse parallel compute function argument list.
128 struct ParallelComputeFunctionArgs {
129   BlockArgument blockIndex();
130   BlockArgument blockSize();
131   ArrayRef<BlockArgument> tripCounts();
132   ArrayRef<BlockArgument> lowerBounds();
133   ArrayRef<BlockArgument> upperBounds();
134   ArrayRef<BlockArgument> steps();
135   ArrayRef<BlockArgument> captures();
136 
137   unsigned numLoops;
138   ArrayRef<BlockArgument> args;
139 };
140 
141 struct ParallelComputeFunctionBounds {
142   SmallVector<IntegerAttr> tripCounts;
143   SmallVector<IntegerAttr> lowerBounds;
144   SmallVector<IntegerAttr> upperBounds;
145   SmallVector<IntegerAttr> steps;
146 };
147 
148 struct ParallelComputeFunction {
149   unsigned numLoops;
150   FuncOp func;
151   llvm::SmallVector<Value> captures;
152 };
153 
154 } // namespace
155 
156 BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
157 BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }
158 
159 ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
160   return args.drop_front(2).take_front(numLoops);
161 }
162 
163 ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
164   return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
165 }
166 
167 ArrayRef<BlockArgument> ParallelComputeFunctionArgs::upperBounds() {
168   return args.drop_front(2 + 2 * numLoops).take_front(numLoops);
169 }
170 
171 ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
172   return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
173 }
174 
175 ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
176   return args.drop_front(2 + 4 * numLoops);
177 }
178 
179 template <typename ValueRange>
180 static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
181   SmallVector<IntegerAttr> attrs(values.size());
182   for (unsigned i = 0; i < values.size(); ++i)
183     matchPattern(values[i], m_Constant(&attrs[i]));
184   return attrs;
185 }
186 
187 // Converts one-dimensional iteration index in the [0, tripCount) interval
188 // into multidimensional iteration coordinate.
189 static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
190                                       ArrayRef<Value> tripCounts) {
191   SmallVector<Value> coords(tripCounts.size());
192   assert(!tripCounts.empty() && "tripCounts must be not empty");
193 
194   for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
195     coords[i] = b.create<arith::RemSIOp>(index, tripCounts[i]);
196     index = b.create<arith::DivSIOp>(index, tripCounts[i]);
197   }
198 
199   return coords;
200 }
201 
202 // Returns a function type and implicit captures for a parallel compute
203 // function. We'll need a list of implicit captures to setup block and value
204 // mapping when we'll clone the body of the parallel operation.
205 static ParallelComputeFunctionType
206 getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
207   // Values implicitly captured by the parallel operation.
208   llvm::SetVector<Value> captures;
209   getUsedValuesDefinedAbove(op.region(), op.region(), captures);
210 
211   SmallVector<Type> inputs;
212   inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
213 
214   Type indexTy = rewriter.getIndexType();
215 
216   // One-dimensional iteration space defined by the block index and size.
217   inputs.push_back(indexTy); // blockIndex
218   inputs.push_back(indexTy); // blockSize
219 
220   // Multi-dimensional parallel iteration space defined by the loop trip counts.
221   for (unsigned i = 0; i < op.getNumLoops(); ++i)
222     inputs.push_back(indexTy); // loop tripCount
223 
224   // Parallel operation lower bound, upper bound and step. Lower bound, upper
225   // bound and step passed as contiguous arguments:
226   //   call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
227   for (unsigned i = 0; i < op.getNumLoops(); ++i) {
228     inputs.push_back(indexTy); // lower bound
229     inputs.push_back(indexTy); // upper bound
230     inputs.push_back(indexTy); // step
231   }
232 
233   // Types of the implicit captures.
234   for (Value capture : captures)
235     inputs.push_back(capture.getType());
236 
237   // Convert captures to vector for later convenience.
238   SmallVector<Value> capturesVector(captures.begin(), captures.end());
239   return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
240 }
241 
242 // Create a parallel compute fuction from the parallel operation.
243 static ParallelComputeFunction
244 createParallelComputeFunction(scf::ParallelOp op,
245                               ParallelComputeFunctionBounds bounds,
246                               PatternRewriter &rewriter) {
247   OpBuilder::InsertionGuard guard(rewriter);
248   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
249 
250   ModuleOp module = op->getParentOfType<ModuleOp>();
251 
252   ParallelComputeFunctionType computeFuncType =
253       getParallelComputeFunctionType(op, rewriter);
254 
255   FunctionType type = computeFuncType.type;
256   FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
257   func.setPrivate();
258 
259   // Insert function into the module symbol table and assign it unique name.
260   SymbolTable symbolTable(module);
261   symbolTable.insert(func);
262   rewriter.getListener()->notifyOperationInserted(func);
263 
264   // Create function entry block.
265   Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
266   b.setInsertionPointToEnd(block);
267 
268   ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
269 
270   // Block iteration position defined by the block index and size.
271   BlockArgument blockIndex = args.blockIndex();
272   BlockArgument blockSize = args.blockSize();
273 
274   // Constants used below.
275   Value c0 = b.create<arith::ConstantIndexOp>(0);
276   Value c1 = b.create<arith::ConstantIndexOp>(1);
277 
278   // Materialize known constants as constant operation in the function body.
279   auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
280     return llvm::to_vector(
281         llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
282           if (IntegerAttr attr = std::get<1>(tuple))
283             return b.create<ConstantOp>(attr);
284           return std::get<0>(tuple);
285         }));
286   };
287 
288   // Multi-dimensional parallel iteration space defined by the loop trip counts.
289   auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
290 
291   // Parallel operation lower bound and step.
292   auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
293   auto steps = values(args.steps(), bounds.steps);
294 
295   // Remaining arguments are implicit captures of the parallel operation.
296   ArrayRef<BlockArgument> captures = args.captures();
297 
298   // Compute a product of trip counts to get the size of the flattened
299   // one-dimensional iteration space.
300   Value tripCount = tripCounts[0];
301   for (unsigned i = 1; i < tripCounts.size(); ++i)
302     tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
303 
304   // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
305   //   blockFirstIndex = blockIndex * blockSize
306   Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
307 
308   // The last one-dimensional index in the block defined by the `blockIndex`:
309   //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
310   Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
311   Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
312   Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);
313 
314   // Convert one-dimensional indices to multi-dimensional coordinates.
315   auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
316   auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);
317 
318   // Compute loops upper bounds derived from the block last coordinates:
319   //   blockEndCoord[i] = blockLastCoord[i] + 1
320   //
321   // Block first and last coordinates can be the same along the outer compute
322   // dimension when inner compute dimension contains multiple blocks.
323   SmallVector<Value> blockEndCoord(op.getNumLoops());
324   for (size_t i = 0; i < blockLastCoord.size(); ++i)
325     blockEndCoord[i] = b.create<arith::AddIOp>(blockLastCoord[i], c1);
326 
327   // Construct a loop nest out of scf.for operations that will iterate over
328   // all coordinates in [blockFirstCoord, blockLastCoord] range.
329   using LoopBodyBuilder =
330       std::function<void(OpBuilder &, Location, Value, ValueRange)>;
331   using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;
332 
333   // Parallel region induction variables computed from the multi-dimensional
334   // iteration coordinate using parallel operation bounds and step:
335   //
336   //   computeBlockInductionVars[loopIdx] =
337   //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
338   SmallVector<Value> computeBlockInductionVars(op.getNumLoops());
339 
340   // We need to know if we are in the first or last iteration of the
341   // multi-dimensional loop for each loop in the nest, so we can decide what
342   // loop bounds should we use for the nested loops: bounds defined by compute
343   // block interval, or bounds defined by the parallel operation.
344   //
345   // Example: 2d parallel operation
346   //                   i   j
347   //   loop sizes:   [50, 50]
348   //   first coord:  [25, 25]
349   //   last coord:   [30, 30]
350   //
351   // If `i` is equal to 25 then iteration over `j` should start at 25, when `i`
352   // is between 25 and 30 it should start at 0. The upper bound for `j` should
353   // be 50, except when `i` is equal to 30, then it should also be 30.
354   //
355   // Value at ith position specifies if all loops in [0, i) range of the loop
356   // nest are in the first/last iteration.
357   SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
358   SmallVector<Value> isBlockLastCoord(op.getNumLoops());
359 
360   // Builds inner loop nest inside async.execute operation that does all the
361   // work concurrently.
362   LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
363     return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
364                         ValueRange args) {
365       ImplicitLocOpBuilder nb(loc, nestedBuilder);
366 
367       // Compute induction variable for `loopIdx`.
368       computeBlockInductionVars[loopIdx] = nb.create<arith::AddIOp>(
369           lowerBounds[loopIdx], nb.create<arith::MulIOp>(iv, steps[loopIdx]));
370 
371       // Check if we are inside first or last iteration of the loop.
372       isBlockFirstCoord[loopIdx] = nb.create<arith::CmpIOp>(
373           arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
374       isBlockLastCoord[loopIdx] = nb.create<arith::CmpIOp>(
375           arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
376 
377       // Check if the previous loop is in its first or last iteration.
378       if (loopIdx > 0) {
379         isBlockFirstCoord[loopIdx] = nb.create<arith::AndIOp>(
380             isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
381         isBlockLastCoord[loopIdx] = nb.create<arith::AndIOp>(
382             isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
383       }
384 
385       // Keep building loop nest.
386       if (loopIdx < op.getNumLoops() - 1) {
387         // Select nested loop lower/upper bounds depending on our position in
388         // the multi-dimensional iteration space.
389         auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
390                                       blockFirstCoord[loopIdx + 1], c0);
391 
392         auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
393                                       blockEndCoord[loopIdx + 1],
394                                       tripCounts[loopIdx + 1]);
395 
396         nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
397                               workLoopBuilder(loopIdx + 1));
398         nb.create<scf::YieldOp>(loc);
399         return;
400       }
401 
402       // Copy the body of the parallel op into the inner-most loop.
403       BlockAndValueMapping mapping;
404       mapping.map(op.getInductionVars(), computeBlockInductionVars);
405       mapping.map(computeFuncType.captures, captures);
406 
407       for (auto &bodyOp : op.getLoopBody().getOps())
408         nb.clone(bodyOp, mapping);
409     };
410   };
411 
412   b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
413                        workLoopBuilder(0));
414   b.create<ReturnOp>(ValueRange());
415 
416   return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
417 }
418 
419 // Creates recursive async dispatch function for the given parallel compute
420 // function. Dispatch function keeps splitting block range into halves until it
421 // reaches a single block, and then excecutes it inline.
422 //
423 // Function pseudocode (mix of C++ and MLIR):
424 //
425 //   func @async_dispatch(%block_start : index, %block_end : index, ...) {
426 //
427 //     // Keep splitting block range until we reached a range of size 1.
428 //     while (%block_end - %block_start > 1) {
429 //       %mid_index = block_start + (block_end - block_start) / 2;
430 //       async.execute { call @async_dispatch(%mid_index, %block_end); }
431 //       %block_end = %mid_index
432 //     }
433 //
434 //     // Call parallel compute function for a single block.
435 //     call @parallel_compute_fn(%block_start, %block_size, ...);
436 //   }
437 //
438 static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
439                                           PatternRewriter &rewriter) {
440   OpBuilder::InsertionGuard guard(rewriter);
441   Location loc = computeFunc.func.getLoc();
442   ImplicitLocOpBuilder b(loc, rewriter);
443 
444   ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
445 
446   ArrayRef<Type> computeFuncInputTypes =
447       computeFunc.func.type().cast<FunctionType>().getInputs();
448 
449   // Compared to the parallel compute function async dispatch function takes
450   // additional !async.group argument. Also instead of a single `blockIndex` it
451   // takes `blockStart` and `blockEnd` arguments to define the range of
452   // dispatched blocks.
453   SmallVector<Type> inputTypes;
454   inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
455   inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
456   inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());
457 
458   FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
459   FuncOp func = FuncOp::create(loc, "async_dispatch_fn", type);
460   func.setPrivate();
461 
462   // Insert function into the module symbol table and assign it unique name.
463   SymbolTable symbolTable(module);
464   symbolTable.insert(func);
465   rewriter.getListener()->notifyOperationInserted(func);
466 
467   // Create function entry block.
468   Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
469   b.setInsertionPointToEnd(block);
470 
471   Type indexTy = b.getIndexType();
472   Value c1 = b.create<arith::ConstantIndexOp>(1);
473   Value c2 = b.create<arith::ConstantIndexOp>(2);
474 
475   // Get the async group that will track async dispatch completion.
476   Value group = block->getArgument(0);
477 
478   // Get the block iteration range: [blockStart, blockEnd)
479   Value blockStart = block->getArgument(1);
480   Value blockEnd = block->getArgument(2);
481 
482   // Create a work splitting while loop for the [blockStart, blockEnd) range.
483   SmallVector<Type> types = {indexTy, indexTy};
484   SmallVector<Value> operands = {blockStart, blockEnd};
485 
486   // Create a recursive dispatch loop.
487   scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
488   Block *before = b.createBlock(&whileOp.before(), {}, types);
489   Block *after = b.createBlock(&whileOp.after(), {}, types);
490 
491   // Setup dispatch loop condition block: decide if we need to go into the
492   // `after` block and launch one more async dispatch.
493   {
494     b.setInsertionPointToEnd(before);
495     Value start = before->getArgument(0);
496     Value end = before->getArgument(1);
497     Value distance = b.create<arith::SubIOp>(end, start);
498     Value dispatch =
499         b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1);
500     b.create<scf::ConditionOp>(dispatch, before->getArguments());
501   }
502 
503   // Setup the async dispatch loop body: recursively call dispatch function
504   // for the seconds half of the original range and go to the next iteration.
505   {
506     b.setInsertionPointToEnd(after);
507     Value start = after->getArgument(0);
508     Value end = after->getArgument(1);
509     Value distance = b.create<arith::SubIOp>(end, start);
510     Value halfDistance = b.create<arith::DivSIOp>(distance, c2);
511     Value midIndex = b.create<arith::AddIOp>(start, halfDistance);
512 
513     // Call parallel compute function inside the async.execute region.
514     auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
515                                   Location executeLoc, ValueRange executeArgs) {
516       // Update the original `blockStart` and `blockEnd` with new range.
517       SmallVector<Value> operands{block->getArguments().begin(),
518                                   block->getArguments().end()};
519       operands[1] = midIndex;
520       operands[2] = end;
521 
522       executeBuilder.create<CallOp>(executeLoc, func.sym_name(),
523                                     func.getCallableResults(), operands);
524       executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
525     };
526 
527     // Create async.execute operation to dispatch half of the block range.
528     auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
529                                        executeBodyBuilder);
530     b.create<AddToGroupOp>(indexTy, execute.token(), group);
531     b.create<scf::YieldOp>(ValueRange({start, midIndex}));
532   }
533 
534   // After dispatching async operations to process the tail of the block range
535   // call the parallel compute function for the first block of the range.
536   b.setInsertionPointAfter(whileOp);
537 
538   // Drop async dispatch specific arguments: async group, block start and end.
539   auto forwardedInputs = block->getArguments().drop_front(3);
540   SmallVector<Value> computeFuncOperands = {blockStart};
541   computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
542 
543   b.create<CallOp>(computeFunc.func.sym_name(),
544                    computeFunc.func.getCallableResults(), computeFuncOperands);
545   b.create<ReturnOp>(ValueRange());
546 
547   return func;
548 }
549 
550 // Launch async dispatch of the parallel compute function.
551 static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
552                             ParallelComputeFunction &parallelComputeFunction,
553                             scf::ParallelOp op, Value blockSize,
554                             Value blockCount,
555                             const SmallVector<Value> &tripCounts) {
556   MLIRContext *ctx = op->getContext();
557 
558   // Add one more level of indirection to dispatch parallel compute functions
559   // using async operations and recursive work splitting.
560   FuncOp asyncDispatchFunction =
561       createAsyncDispatchFunction(parallelComputeFunction, rewriter);
562 
563   Value c0 = b.create<arith::ConstantIndexOp>(0);
564   Value c1 = b.create<arith::ConstantIndexOp>(1);
565 
566   // Appends operands shared by async dispatch and parallel compute functions to
567   // the given operands vector.
568   auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
569     operands.append(tripCounts);
570     operands.append(op.lowerBound().begin(), op.lowerBound().end());
571     operands.append(op.upperBound().begin(), op.upperBound().end());
572     operands.append(op.step().begin(), op.step().end());
573     operands.append(parallelComputeFunction.captures);
574   };
575 
576   // Check if the block size is one, in this case we can skip the async dispatch
577   // completely. If this will be known statically, then canonicalization will
578   // erase async group operations.
579   Value isSingleBlock =
580       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);
581 
582   auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
583     ImplicitLocOpBuilder nb(loc, nestedBuilder);
584 
585     // Call parallel compute function for the single block.
586     SmallVector<Value> operands = {c0, blockSize};
587     appendBlockComputeOperands(operands);
588 
589     nb.create<CallOp>(parallelComputeFunction.func.sym_name(),
590                       parallelComputeFunction.func.getCallableResults(),
591                       operands);
592     nb.create<scf::YieldOp>();
593   };
594 
595   auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
596     // Create an async.group to wait on all async tokens from the concurrent
597     // execution of multiple parallel compute function. First block will be
598     // executed synchronously in the caller thread.
599     Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
600     Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
601 
602     ImplicitLocOpBuilder nb(loc, nestedBuilder);
603 
604     // Launch async dispatch function for [0, blockCount) range.
605     SmallVector<Value> operands = {group, c0, blockCount, blockSize};
606     appendBlockComputeOperands(operands);
607 
608     nb.create<CallOp>(asyncDispatchFunction.sym_name(),
609                       asyncDispatchFunction.getCallableResults(), operands);
610 
611     // Wait for the completion of all parallel compute operations.
612     b.create<AwaitAllOp>(group);
613 
614     nb.create<scf::YieldOp>();
615   };
616 
617   // Dispatch either single block compute function, or launch async dispatch.
618   b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
619 }
620 
621 // Dispatch parallel compute functions by submitting all async compute tasks
622 // from a simple for loop in the caller thread.
623 static void
624 doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
625                      ParallelComputeFunction &parallelComputeFunction,
626                      scf::ParallelOp op, Value blockSize, Value blockCount,
627                      const SmallVector<Value> &tripCounts) {
628   MLIRContext *ctx = op->getContext();
629 
630   FuncOp compute = parallelComputeFunction.func;
631 
632   Value c0 = b.create<arith::ConstantIndexOp>(0);
633   Value c1 = b.create<arith::ConstantIndexOp>(1);
634 
635   // Create an async.group to wait on all async tokens from the concurrent
636   // execution of multiple parallel compute function. First block will be
637   // executed synchronously in the caller thread.
638   Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
639   Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
640 
641   // Call parallel compute function for all blocks.
642   using LoopBodyBuilder =
643       std::function<void(OpBuilder &, Location, Value, ValueRange)>;
644 
645   // Returns parallel compute function operands to process the given block.
646   auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
647     SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
648     computeFuncOperands.append(tripCounts);
649     computeFuncOperands.append(op.lowerBound().begin(), op.lowerBound().end());
650     computeFuncOperands.append(op.upperBound().begin(), op.upperBound().end());
651     computeFuncOperands.append(op.step().begin(), op.step().end());
652     computeFuncOperands.append(parallelComputeFunction.captures);
653     return computeFuncOperands;
654   };
655 
656   // Induction variable is the index of the block: [0, blockCount).
657   LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
658                                     Value iv, ValueRange args) {
659     ImplicitLocOpBuilder nb(loc, loopBuilder);
660 
661     // Call parallel compute function inside the async.execute region.
662     auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
663                                   Location executeLoc, ValueRange executeArgs) {
664       executeBuilder.create<CallOp>(executeLoc, compute.sym_name(),
665                                     compute.getCallableResults(),
666                                     computeFuncOperands(iv));
667       executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
668     };
669 
670     // Create async.execute operation to launch parallel computate function.
671     auto execute = nb.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
672                                         executeBodyBuilder);
673     nb.create<AddToGroupOp>(rewriter.getIndexType(), execute.token(), group);
674     nb.create<scf::YieldOp>();
675   };
676 
677   // Iterate over all compute blocks and launch parallel compute operations.
678   b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);
679 
680   // Call parallel compute function for the first block in the caller thread.
681   b.create<CallOp>(compute.sym_name(), compute.getCallableResults(),
682                    computeFuncOperands(c0));
683 
684   // Wait for the completion of all async compute operations.
685   b.create<AwaitAllOp>(group);
686 }
687 
688 LogicalResult
689 AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
690                                          PatternRewriter &rewriter) const {
691   // We do not currently support rewrite for parallel op with reductions.
692   if (op.getNumReductions() != 0)
693     return failure();
694 
695   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
696 
697   // Make sure that all constants will be inside the parallel operation body to
698   // reduce the number of parallel compute function arguments.
699   cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
700 
701   // Compute trip count for each loop induction variable:
702   //   tripCount = ceil_div(upperBound - lowerBound, step);
703   SmallVector<Value> tripCounts(op.getNumLoops());
704   for (size_t i = 0; i < op.getNumLoops(); ++i) {
705     auto lb = op.lowerBound()[i];
706     auto ub = op.upperBound()[i];
707     auto step = op.step()[i];
708     auto range = b.createOrFold<arith::SubIOp>(ub, lb);
709     tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
710   }
711 
712   // Compute a product of trip counts to get the 1-dimensional iteration space
713   // for the scf.parallel operation.
714   Value tripCount = tripCounts[0];
715   for (size_t i = 1; i < tripCounts.size(); ++i)
716     tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
717 
718   // Short circuit no-op parallel loops (zero iterations) that can arise from
719   // the memrefs with dynamic dimension(s) equal to zero.
720   Value c0 = b.create<arith::ConstantIndexOp>(0);
721   Value isZeroIterations =
722       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);
723 
724   // Do absolutely nothing if the trip count is zero.
725   auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
726     nestedBuilder.create<scf::YieldOp>(loc);
727   };
728 
729   // Compute the parallel block size and dispatch concurrent tasks computing
730   // results for each block.
731   auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
732     ImplicitLocOpBuilder nb(loc, nestedBuilder);
733 
734     // With large number of threads the value of creating many compute blocks
735     // is reduced because the problem typically becomes memory bound. For small
736     // number of threads it helps with stragglers.
737     float overshardingFactor = numWorkerThreads <= 4    ? 8.0
738                                : numWorkerThreads <= 8  ? 4.0
739                                : numWorkerThreads <= 16 ? 2.0
740                                : numWorkerThreads <= 32 ? 1.0
741                                : numWorkerThreads <= 64 ? 0.8
742                                                         : 0.6;
743 
744     // Do not overload worker threads with too many compute blocks.
745     Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
746         std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
747 
748     // Target block size from the pass parameters.
749     Value minTaskSizeCst = b.create<arith::ConstantIndexOp>(minTaskSize);
750 
751     // Compute parallel block size from the parallel problem size:
752     //   blockSize = min(tripCount,
753     //                   max(ceil_div(tripCount, maxComputeBlocks),
754     //                       ceil_div(minTaskSize, bodySize)))
755     Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
756     Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
757     Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
758     Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
759 
760     // Collect statically known constants defining the loop nest in the parallel
761     // compute function. LLVM can't always push constants across the non-trivial
762     // async dispatch call graph, by providing these values explicitly we can
763     // choose to build more efficient loop nest, and rely on a better constant
764     // folding, loop unrolling and vectorization.
765     ParallelComputeFunctionBounds staticBounds = {
766         integerConstants(tripCounts),
767         integerConstants(op.lowerBound()),
768         integerConstants(op.upperBound()),
769         integerConstants(op.step()),
770     };
771 
772     // Create a parallel compute function that takes a block id and computes the
773     // parallel operation body for a subset of iteration space.
774     ParallelComputeFunction parallelComputeFunction =
775         createParallelComputeFunction(op, staticBounds, rewriter);
776 
777     // Dispatch parallel compute function using async recursive work splitting,
778     // or by submitting compute task sequentially from a caller thread.
779     if (asyncDispatch) {
780       doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
781                       blockCount, tripCounts);
782     } else {
783       doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
784                            blockCount, tripCounts);
785     }
786 
787     nb.create<scf::YieldOp>();
788   };
789 
790   // Replace the `scf.parallel` operation with the parallel compute function.
791   b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch);
792 
793   // Parallel operation was replaced with a block iteration loop.
794   rewriter.eraseOp(op);
795 
796   return success();
797 }
798 
799 void AsyncParallelForPass::runOnOperation() {
800   MLIRContext *ctx = &getContext();
801 
802   RewritePatternSet patterns(ctx);
803   patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
804                                         minTaskSize);
805 
806   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
807     signalPassFailure();
808 }
809 
810 std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
811   return std::make_unique<AsyncParallelForPass>();
812 }
813 
814 std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
815                                                        int32_t numWorkerThreads,
816                                                        int32_t minTaskSize) {
817   return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
818                                                 minTaskSize);
819 }
820