1 //===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the scf.parallel to scf.for + async.execute conversion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Async/IR/Async.h"
15 #include "mlir/Dialect/Async/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::async;
24 
25 #define DEBUG_TYPE "async-parallel-for"
26 
27 namespace {
28 
29 // Rewrite scf.parallel operation into multiple concurrent async.execute
30 // operations over non overlapping subranges of the original loop.
31 //
32 // Example:
33 //
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
35 //     "do_some_compute"(%i, %j): () -> ()
36 //   }
37 //
38 // Converted to:
39 //
40 //   %c0 = constant 0 : index
41 //   %c1 = constant 1 : index
42 //
43 //   // Compute blocks sizes for each induction variable.
44 //   %num_blocks_i = ... : index
45 //   %num_blocks_j = ... : index
46 //   %block_size_i = ... : index
47 //   %block_size_j = ... : index
48 //
49 //   // Create an async group to track async execute ops.
50 //   %group = async.create_group
51 //
52 //   scf.for %bi = %c0 to %num_blocks_i step %c1 {
53 //     %block_start_i = ... : index
54 //     %block_end_i   = ... : index
55 //
56 //     scf.for %bj = %c0 to %num_blocks_j step %c1 {
57 //       %block_start_j = ... : index
58 //       %block_end_j   = ... : index
59 //
60 //       // Execute the body of original parallel operation for the current
61 //       // block.
62 //       %token = async.execute {
63 //         scf.for %i = %block_start_i to %block_end_i step %si {
64 //           scf.for %j = %block_start_j to %block_end_j step %sj {
65 //             "do_some_compute"(%i, %j): () -> ()
66 //           }
67 //         }
68 //       }
69 //
70 //       // Add produced async token to the group.
71 //       async.add_to_group %token, %group
72 //     }
73 //   }
74 //
75 //   // Await completion of all async.execute operations.
76 //   async.await_all %group
77 //
78 // In this example outer loop launches inner block level loops as separate async
79 // execute operations which will be executed concurrently.
80 //
// At the end it waits for the completion of all async execute operations.
82 //
// Rewrite pattern that converts a single scf.parallel operation (without
// reductions) into a dispatch loop nest launching async.execute operations,
// following the scheme sketched in the comment above.
struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  // `numConcurrentAsyncExecute` is the target number of concurrently
  // outstanding async.execute operations used to derive block sizes.
  AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute)
      : OpRewritePattern(ctx),
        numConcurrentAsyncExecute(numConcurrentAsyncExecute) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  // Target concurrency budget distributed across the loop dimensions.
  int numConcurrentAsyncExecute;
};
95 
// Function pass that greedily applies AsyncParallelForRewrite to all
// scf.parallel operations in the function.
struct AsyncParallelForPass
    : public AsyncParallelForBase<AsyncParallelForPass> {
  AsyncParallelForPass() = default;
  void runOnFunction() override;
};
101 
102 } // namespace
103 
// Rewrites `op` (an scf.parallel without reductions) into:
//   1) IR computing a block size and block count for each induction variable,
//   2) an scf.for dispatch loop nest that creates one async.execute op per
//      block and adds its token to an async group,
//   3) an async.await_all on the group after the dispatch loops.
// Returns failure() (no rewrite) for parallel ops with reductions.
LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support rewrite for parallel op with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();

  // Index constants used below.
  auto indexTy = IndexType::get(ctx);
  auto zero = IntegerAttr::get(indexTy, 0);
  auto one = IntegerAttr::get(indexTy, 1);
  auto c0 = rewriter.create<ConstantOp>(loc, indexTy, zero);
  auto c1 = rewriter.create<ConstantOp>(loc, indexTy, one);

  // Shorthand for signed integer ceil division operation.
  auto divup = [&](Value x, Value y) -> Value {
    return rewriter.create<SignedCeilDivIOp>(loc, x, y);
  };

  // Compute trip count for each loop induction variable:
  //   tripCount = divUp(upperBound - lowerBound, step);
  SmallVector<Value, 4> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.lowerBound()[i];
    auto ub = op.upperBound()[i];
    auto step = op.step()[i];
    auto range = rewriter.create<SubIOp>(loc, ub, lb);
    tripCounts[i] = divup(range, step);
  }

  // The target number of concurrent async.execute ops.
  auto numExecuteOps = rewriter.create<ConstantOp>(
      loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute));

  // Block sizes configuration for each induction variable.

  // We try to use maximum available concurrency in outer dimensions first
  // (assuming that parallel induction variables are corresponding to some
  // multidimensional access, e.g. in (%d0, %d1, ..., %dn) = (<from>) to (<to>)
  // we will try to parallelize iteration along the %d0. If %d0 is too small,
  // we'll parallelize iteration over %d1, and so on.
  SmallVector<Value, 4> targetNumBlocks(op.getNumLoops());
  SmallVector<Value, 4> blockSize(op.getNumLoops());
  SmallVector<Value, 4> numBlocks(op.getNumLoops());

  // Compute block size and number of blocks along the first induction variable.
  targetNumBlocks[0] = numExecuteOps;
  blockSize[0] = divup(tripCounts[0], targetNumBlocks[0]);
  numBlocks[0] = divup(tripCounts[0], blockSize[0]);

  // Assign remaining available concurrency to other induction variables:
  // targetNumBlocks[i] is the budget left over after dimension i-1 consumed
  // numBlocks[i-1] out of targetNumBlocks[i-1].
  for (size_t i = 1; i < op.getNumLoops(); ++i) {
    targetNumBlocks[i] = divup(targetNumBlocks[i - 1], numBlocks[i - 1]);
    blockSize[i] = divup(tripCounts[i], targetNumBlocks[i]);
    numBlocks[i] = divup(tripCounts[i], blockSize[i]);
  }

  // Create an async.group to wait on all async tokens from async execute ops.
  auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));

  // Build a scf.for loop nest from the parallel operation.

  // Lower/upper bounds and induction variables for the per-block loop nests.
  // These are written by the dispatch loop bodies (asyncLoopBuilder) before
  // they are read by the work loop bodies (workLoopBuilder); both lambdas
  // capture them by reference.
  SmallVector<Value, 4> blockLowerBounds(op.getNumLoops());
  SmallVector<Value, 4> blockUpperBounds(op.getNumLoops());
  SmallVector<Value, 4> blockInductionVars(op.getNumLoops());

  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Builds inner loop nest inside async.execute operation that does all the
  // work concurrently.
  LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
      blockInductionVars[loopIdx] = iv;

      // Continue building the work loop nest for the next dimension.
      if (loopIdx < op.getNumLoops() - 1) {
        b.create<scf::ForOp>(
            loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1],
            op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1));
        b.create<scf::YieldOp>(loc);
        return;
      }

      // Innermost dimension: copy the body of the parallel op with the
      // original induction variables remapped to the block loop ones.
      BlockAndValueMapping mapping;
      mapping.map(op.getInductionVars(), blockInductionVars);

      for (auto &bodyOp : op.getLoopBody().getOps())
        b.clone(bodyOp, mapping);
    };
  };

  // Builds a loop nest that does async execute op dispatching.
  LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
      auto lb = op.lowerBound()[loopIdx];
      auto ub = op.upperBound()[loopIdx];
      auto step = op.step()[loopIdx];

      // Compute lower bound for the current block:
      //   blockLowerBound = iv * blockSize * step + lowerBound
      auto s0 = b.create<MulIOp>(loc, iv, blockSize[loopIdx]);
      auto s1 = b.create<MulIOp>(loc, s0, step);
      auto s2 = b.create<AddIOp>(loc, s1, lb);
      blockLowerBounds[loopIdx] = s2;

      // Compute upper bound for the current block (clamped so the last,
      // possibly partial, block does not run past the original bound):
      //   blockUpperBound = min(upperBound,
      //                         blockLowerBound + blockSize * step)
      auto e0 = b.create<MulIOp>(loc, blockSize[loopIdx], step);
      auto e1 = b.create<AddIOp>(loc, e0, s2);
      auto e2 = b.create<CmpIOp>(loc, CmpIPredicate::slt, e1, ub);
      auto e3 = b.create<SelectOp>(loc, e2, e1, ub);
      blockUpperBounds[loopIdx] = e3;

      // Continue building the dispatch loop nest for the next dimension.
      if (loopIdx < op.getNumLoops() - 1) {
        b.create<scf::ForOp>(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(),
                             asyncLoopBuilder(loopIdx + 1));
        b.create<scf::YieldOp>(loc);
        return;
      }

      // Build the inner loop nest that will do the actual work inside the
      // `async.execute` body region.
      auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                    Location executeLoc,
                                    ValueRange executeArgs) {
        executeBuilder.create<scf::ForOp>(executeLoc, blockLowerBounds[0],
                                          blockUpperBounds[0], op.step()[0],
                                          ValueRange(), workLoopBuilder(0));
        executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
      };

      auto execute = b.create<ExecuteOp>(
          loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(),
          /*operands=*/ValueRange(), executeBodyBuilder);
      // add_to_group produces an index-typed result; it is unused here but
      // the op still needs a result type.
      auto rankType = IndexType::get(ctx);
      b.create<AddToGroupOp>(loc, rankType, execute.token(), group.result());
      b.create<scf::YieldOp>(loc);
    };
  };

  // Start building a loop nest from the first induction variable.
  rewriter.create<scf::ForOp>(loc, c0, numBlocks[0], c1, ValueRange(),
                              asyncLoopBuilder(0));

  // Wait for the completion of all subtasks.
  rewriter.create<AwaitAllOp>(loc, group.result());

  // Erase the original parallel operation.
  rewriter.eraseOp(op);

  return success();
}
265 
266 void AsyncParallelForPass::runOnFunction() {
267   MLIRContext *ctx = &getContext();
268 
269   OwningRewritePatternList patterns;
270   patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
271 
272   if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
273     signalPassFailure();
274 }
275 
276 std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
277   return std::make_unique<AsyncParallelForPass>();
278 }
279