1*c30ab6c2SEugene Zhulenev //===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
2*c30ab6c2SEugene Zhulenev //
3*c30ab6c2SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*c30ab6c2SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5*c30ab6c2SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*c30ab6c2SEugene Zhulenev //
7*c30ab6c2SEugene Zhulenev //===----------------------------------------------------------------------===//
8*c30ab6c2SEugene Zhulenev //
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
10*c30ab6c2SEugene Zhulenev //
11*c30ab6c2SEugene Zhulenev //===----------------------------------------------------------------------===//
12*c30ab6c2SEugene Zhulenev 
13*c30ab6c2SEugene Zhulenev #include "PassDetail.h"
14*c30ab6c2SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
15*c30ab6c2SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h"
16*c30ab6c2SEugene Zhulenev #include "mlir/Dialect/SCF/SCF.h"
17*c30ab6c2SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h"
18*c30ab6c2SEugene Zhulenev #include "mlir/IR/BlockAndValueMapping.h"
19*c30ab6c2SEugene Zhulenev #include "mlir/IR/PatternMatch.h"
20*c30ab6c2SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21*c30ab6c2SEugene Zhulenev 
22*c30ab6c2SEugene Zhulenev using namespace mlir;
23*c30ab6c2SEugene Zhulenev using namespace mlir::async;
24*c30ab6c2SEugene Zhulenev 
25*c30ab6c2SEugene Zhulenev #define DEBUG_TYPE "async-parallel-for"
26*c30ab6c2SEugene Zhulenev 
27*c30ab6c2SEugene Zhulenev namespace {
28*c30ab6c2SEugene Zhulenev 
29*c30ab6c2SEugene Zhulenev // Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
31*c30ab6c2SEugene Zhulenev //
32*c30ab6c2SEugene Zhulenev // Example:
33*c30ab6c2SEugene Zhulenev //
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
35*c30ab6c2SEugene Zhulenev //     "do_some_compute"(%i, %j): () -> ()
36*c30ab6c2SEugene Zhulenev //   }
37*c30ab6c2SEugene Zhulenev //
38*c30ab6c2SEugene Zhulenev // Converted to:
39*c30ab6c2SEugene Zhulenev //
40*c30ab6c2SEugene Zhulenev //   %c0 = constant 0 : index
41*c30ab6c2SEugene Zhulenev //   %c1 = constant 1 : index
42*c30ab6c2SEugene Zhulenev //
//   // Compute block sizes for each induction variable.
44*c30ab6c2SEugene Zhulenev //   %num_blocks_i = ... : index
45*c30ab6c2SEugene Zhulenev //   %num_blocks_j = ... : index
46*c30ab6c2SEugene Zhulenev //   %block_size_i = ... : index
47*c30ab6c2SEugene Zhulenev //   %block_size_j = ... : index
48*c30ab6c2SEugene Zhulenev //
49*c30ab6c2SEugene Zhulenev //   // Create an async group to track async execute ops.
50*c30ab6c2SEugene Zhulenev //   %group = async.create_group
51*c30ab6c2SEugene Zhulenev //
52*c30ab6c2SEugene Zhulenev //   scf.for %bi = %c0 to %num_blocks_i step %c1 {
53*c30ab6c2SEugene Zhulenev //     %block_start_i = ... : index
54*c30ab6c2SEugene Zhulenev //     %block_end_i   = ... : index
55*c30ab6c2SEugene Zhulenev //
56*c30ab6c2SEugene Zhulenev //     scf.for %bj = %c0 to %num_blocks_j step %c1 {
57*c30ab6c2SEugene Zhulenev //       %block_start_j = ... : index
58*c30ab6c2SEugene Zhulenev //       %block_end_j   = ... : index
59*c30ab6c2SEugene Zhulenev //
60*c30ab6c2SEugene Zhulenev //       // Execute the body of original parallel operation for the current
61*c30ab6c2SEugene Zhulenev //       // block.
62*c30ab6c2SEugene Zhulenev //       %token = async.execute {
63*c30ab6c2SEugene Zhulenev //         scf.for %i = %block_start_i to %block_end_i step %si {
64*c30ab6c2SEugene Zhulenev //           scf.for %j = %block_start_j to %block_end_j step %sj {
65*c30ab6c2SEugene Zhulenev //             "do_some_compute"(%i, %j): () -> ()
66*c30ab6c2SEugene Zhulenev //           }
67*c30ab6c2SEugene Zhulenev //         }
68*c30ab6c2SEugene Zhulenev //       }
69*c30ab6c2SEugene Zhulenev //
70*c30ab6c2SEugene Zhulenev //       // Add produced async token to the group.
71*c30ab6c2SEugene Zhulenev //       async.add_to_group %token, %group
72*c30ab6c2SEugene Zhulenev //     }
73*c30ab6c2SEugene Zhulenev //   }
74*c30ab6c2SEugene Zhulenev //
75*c30ab6c2SEugene Zhulenev //   // Await completion of all async.execute operations.
76*c30ab6c2SEugene Zhulenev //   async.await_all %group
77*c30ab6c2SEugene Zhulenev //
78*c30ab6c2SEugene Zhulenev // In this example outer loop launches inner block level loops as separate async
79*c30ab6c2SEugene Zhulenev // execute operations which will be executed concurrently.
80*c30ab6c2SEugene Zhulenev //
// At the end it waits for the completion of all async execute operations.
82*c30ab6c2SEugene Zhulenev //
83*c30ab6c2SEugene Zhulenev struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
84*c30ab6c2SEugene Zhulenev public:
85*c30ab6c2SEugene Zhulenev   AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute)
86*c30ab6c2SEugene Zhulenev       : OpRewritePattern(ctx),
87*c30ab6c2SEugene Zhulenev         numConcurrentAsyncExecute(numConcurrentAsyncExecute) {}
88*c30ab6c2SEugene Zhulenev 
89*c30ab6c2SEugene Zhulenev   LogicalResult matchAndRewrite(scf::ParallelOp op,
90*c30ab6c2SEugene Zhulenev                                 PatternRewriter &rewriter) const override;
91*c30ab6c2SEugene Zhulenev 
92*c30ab6c2SEugene Zhulenev private:
93*c30ab6c2SEugene Zhulenev   int numConcurrentAsyncExecute;
94*c30ab6c2SEugene Zhulenev };
95*c30ab6c2SEugene Zhulenev 
96*c30ab6c2SEugene Zhulenev struct AsyncParallelForPass
97*c30ab6c2SEugene Zhulenev     : public AsyncParallelForBase<AsyncParallelForPass> {
98*c30ab6c2SEugene Zhulenev   AsyncParallelForPass() = default;
99*c30ab6c2SEugene Zhulenev   void runOnFunction() override;
100*c30ab6c2SEugene Zhulenev };
101*c30ab6c2SEugene Zhulenev 
102*c30ab6c2SEugene Zhulenev } // namespace
103*c30ab6c2SEugene Zhulenev 
104*c30ab6c2SEugene Zhulenev LogicalResult
105*c30ab6c2SEugene Zhulenev AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
106*c30ab6c2SEugene Zhulenev                                          PatternRewriter &rewriter) const {
107*c30ab6c2SEugene Zhulenev   // We do not currently support rewrite for parallel op with reductions.
108*c30ab6c2SEugene Zhulenev   if (op.getNumReductions() != 0)
109*c30ab6c2SEugene Zhulenev     return failure();
110*c30ab6c2SEugene Zhulenev 
111*c30ab6c2SEugene Zhulenev   MLIRContext *ctx = op.getContext();
112*c30ab6c2SEugene Zhulenev   Location loc = op.getLoc();
113*c30ab6c2SEugene Zhulenev 
114*c30ab6c2SEugene Zhulenev   // Index constants used below.
115*c30ab6c2SEugene Zhulenev   auto indexTy = IndexType::get(ctx);
116*c30ab6c2SEugene Zhulenev   auto zero = IntegerAttr::get(indexTy, 0);
117*c30ab6c2SEugene Zhulenev   auto one = IntegerAttr::get(indexTy, 1);
118*c30ab6c2SEugene Zhulenev   auto c0 = rewriter.create<ConstantOp>(loc, indexTy, zero);
119*c30ab6c2SEugene Zhulenev   auto c1 = rewriter.create<ConstantOp>(loc, indexTy, one);
120*c30ab6c2SEugene Zhulenev 
121*c30ab6c2SEugene Zhulenev   // Shorthand for signed integer ceil division operation.
122*c30ab6c2SEugene Zhulenev   auto divup = [&](Value x, Value y) -> Value {
123*c30ab6c2SEugene Zhulenev     return rewriter.create<SignedCeilDivIOp>(loc, x, y);
124*c30ab6c2SEugene Zhulenev   };
125*c30ab6c2SEugene Zhulenev 
126*c30ab6c2SEugene Zhulenev   // Compute trip count for each loop induction variable:
127*c30ab6c2SEugene Zhulenev   //   tripCount = divUp(upperBound - lowerBound, step);
128*c30ab6c2SEugene Zhulenev   SmallVector<Value, 4> tripCounts(op.getNumLoops());
129*c30ab6c2SEugene Zhulenev   for (size_t i = 0; i < op.getNumLoops(); ++i) {
130*c30ab6c2SEugene Zhulenev     auto lb = op.lowerBound()[i];
131*c30ab6c2SEugene Zhulenev     auto ub = op.upperBound()[i];
132*c30ab6c2SEugene Zhulenev     auto step = op.step()[i];
133*c30ab6c2SEugene Zhulenev     auto range = rewriter.create<SubIOp>(loc, ub, lb);
134*c30ab6c2SEugene Zhulenev     tripCounts[i] = divup(range, step);
135*c30ab6c2SEugene Zhulenev   }
136*c30ab6c2SEugene Zhulenev 
137*c30ab6c2SEugene Zhulenev   // The target number of concurrent async.execute ops.
138*c30ab6c2SEugene Zhulenev   auto numExecuteOps = rewriter.create<ConstantOp>(
139*c30ab6c2SEugene Zhulenev       loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute));
140*c30ab6c2SEugene Zhulenev 
141*c30ab6c2SEugene Zhulenev   // Blocks sizes configuration for each induction variable.
142*c30ab6c2SEugene Zhulenev 
143*c30ab6c2SEugene Zhulenev   // We try to use maximum available concurrency in outer dimensions first
144*c30ab6c2SEugene Zhulenev   // (assuming that parallel induction variables are corresponding to some
145*c30ab6c2SEugene Zhulenev   // multidimensional access, e.g. in (%d0, %d1, ..., %dn) = (<from>) to (<to>)
146*c30ab6c2SEugene Zhulenev   // we will try to parallelize iteration along the %d0. If %d0 is too small,
147*c30ab6c2SEugene Zhulenev   // we'll parallelize iteration over %d1, and so on.
148*c30ab6c2SEugene Zhulenev   SmallVector<Value, 4> targetNumBlocks(op.getNumLoops());
149*c30ab6c2SEugene Zhulenev   SmallVector<Value, 4> blockSize(op.getNumLoops());
150*c30ab6c2SEugene Zhulenev   SmallVector<Value, 4> numBlocks(op.getNumLoops());
151*c30ab6c2SEugene Zhulenev 
152*c30ab6c2SEugene Zhulenev   // Compute block size and number of blocks along the first induction variable.
153*c30ab6c2SEugene Zhulenev   targetNumBlocks[0] = numExecuteOps;
154*c30ab6c2SEugene Zhulenev   blockSize[0] = divup(tripCounts[0], targetNumBlocks[0]);
155*c30ab6c2SEugene Zhulenev   numBlocks[0] = divup(tripCounts[0], blockSize[0]);
156*c30ab6c2SEugene Zhulenev 
157*c30ab6c2SEugene Zhulenev   // Assign remaining available concurrency to other induction variables.
158*c30ab6c2SEugene Zhulenev   for (size_t i = 1; i < op.getNumLoops(); ++i) {
159*c30ab6c2SEugene Zhulenev     targetNumBlocks[i] = divup(targetNumBlocks[i - 1], numBlocks[i - 1]);
160*c30ab6c2SEugene Zhulenev     blockSize[i] = divup(tripCounts[i], targetNumBlocks[i]);
161*c30ab6c2SEugene Zhulenev     numBlocks[i] = divup(tripCounts[i], blockSize[i]);
162*c30ab6c2SEugene Zhulenev   }
163*c30ab6c2SEugene Zhulenev 
164*c30ab6c2SEugene Zhulenev   // Create an async.group to wait on all async tokens from async execute ops.
165*c30ab6c2SEugene Zhulenev   auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));
166*c30ab6c2SEugene Zhulenev 
167*c30ab6c2SEugene Zhulenev   // Build a scf.for loop nest from the parallel operation.
168*c30ab6c2SEugene Zhulenev 
169*c30ab6c2SEugene Zhulenev   // Lower/upper bounds for nest block level computations.
170*c30ab6c2SEugene Zhulenev   SmallVector<Value, 4> blockLowerBounds(op.getNumLoops());
171*c30ab6c2SEugene Zhulenev   SmallVector<Value, 4> blockUpperBounds(op.getNumLoops());
172*c30ab6c2SEugene Zhulenev   SmallVector<Value, 4> blockInductionVars(op.getNumLoops());
173*c30ab6c2SEugene Zhulenev 
174*c30ab6c2SEugene Zhulenev   using LoopBodyBuilder =
175*c30ab6c2SEugene Zhulenev       std::function<void(OpBuilder &, Location, Value, ValueRange)>;
176*c30ab6c2SEugene Zhulenev   using LoopBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;
177*c30ab6c2SEugene Zhulenev 
178*c30ab6c2SEugene Zhulenev   // Builds inner loop nest inside async.execute operation that does all the
179*c30ab6c2SEugene Zhulenev   // work concurrently.
180*c30ab6c2SEugene Zhulenev   LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
181*c30ab6c2SEugene Zhulenev     return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
182*c30ab6c2SEugene Zhulenev       blockInductionVars[loopIdx] = iv;
183*c30ab6c2SEugene Zhulenev 
184*c30ab6c2SEugene Zhulenev       // Continute building async loop nest.
185*c30ab6c2SEugene Zhulenev       if (loopIdx < op.getNumLoops() - 1) {
186*c30ab6c2SEugene Zhulenev         b.create<scf::ForOp>(
187*c30ab6c2SEugene Zhulenev             loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1],
188*c30ab6c2SEugene Zhulenev             op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1));
189*c30ab6c2SEugene Zhulenev         b.create<scf::YieldOp>(loc);
190*c30ab6c2SEugene Zhulenev         return;
191*c30ab6c2SEugene Zhulenev       }
192*c30ab6c2SEugene Zhulenev 
193*c30ab6c2SEugene Zhulenev       // Copy the body of the parallel op with new loop bounds.
194*c30ab6c2SEugene Zhulenev       BlockAndValueMapping mapping;
195*c30ab6c2SEugene Zhulenev       mapping.map(op.getInductionVars(), blockInductionVars);
196*c30ab6c2SEugene Zhulenev 
197*c30ab6c2SEugene Zhulenev       for (auto &bodyOp : op.getLoopBody().getOps())
198*c30ab6c2SEugene Zhulenev         b.clone(bodyOp, mapping);
199*c30ab6c2SEugene Zhulenev     };
200*c30ab6c2SEugene Zhulenev   };
201*c30ab6c2SEugene Zhulenev 
202*c30ab6c2SEugene Zhulenev   // Builds a loop nest that does async execute op dispatching.
203*c30ab6c2SEugene Zhulenev   LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
204*c30ab6c2SEugene Zhulenev     return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
205*c30ab6c2SEugene Zhulenev       auto lb = op.lowerBound()[loopIdx];
206*c30ab6c2SEugene Zhulenev       auto ub = op.upperBound()[loopIdx];
207*c30ab6c2SEugene Zhulenev       auto step = op.step()[loopIdx];
208*c30ab6c2SEugene Zhulenev 
209*c30ab6c2SEugene Zhulenev       // Compute lower bound for the current block:
210*c30ab6c2SEugene Zhulenev       //   blockLowerBound = iv * blockSize * step + lowerBound
211*c30ab6c2SEugene Zhulenev       auto s0 = b.create<MulIOp>(loc, iv, blockSize[loopIdx]);
212*c30ab6c2SEugene Zhulenev       auto s1 = b.create<MulIOp>(loc, s0, step);
213*c30ab6c2SEugene Zhulenev       auto s2 = b.create<AddIOp>(loc, s1, lb);
214*c30ab6c2SEugene Zhulenev       blockLowerBounds[loopIdx] = s2;
215*c30ab6c2SEugene Zhulenev 
216*c30ab6c2SEugene Zhulenev       // Compute upper bound for the current block:
217*c30ab6c2SEugene Zhulenev       //   blockUpperBound = min(upperBound,
218*c30ab6c2SEugene Zhulenev       //                         blockLowerBound + blockSize * step)
219*c30ab6c2SEugene Zhulenev       auto e0 = b.create<MulIOp>(loc, blockSize[loopIdx], step);
220*c30ab6c2SEugene Zhulenev       auto e1 = b.create<AddIOp>(loc, e0, s2);
221*c30ab6c2SEugene Zhulenev       auto e2 = b.create<CmpIOp>(loc, CmpIPredicate::slt, e1, ub);
222*c30ab6c2SEugene Zhulenev       auto e3 = b.create<SelectOp>(loc, e2, e1, ub);
223*c30ab6c2SEugene Zhulenev       blockUpperBounds[loopIdx] = e3;
224*c30ab6c2SEugene Zhulenev 
225*c30ab6c2SEugene Zhulenev       // Continue building async dispatch loop nest.
226*c30ab6c2SEugene Zhulenev       if (loopIdx < op.getNumLoops() - 1) {
227*c30ab6c2SEugene Zhulenev         b.create<scf::ForOp>(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(),
228*c30ab6c2SEugene Zhulenev                              asyncLoopBuilder(loopIdx + 1));
229*c30ab6c2SEugene Zhulenev         b.create<scf::YieldOp>(loc);
230*c30ab6c2SEugene Zhulenev         return;
231*c30ab6c2SEugene Zhulenev       }
232*c30ab6c2SEugene Zhulenev 
233*c30ab6c2SEugene Zhulenev       // Build the inner loop nest that will do the actual work inside the
234*c30ab6c2SEugene Zhulenev       // `async.execute` body region.
235*c30ab6c2SEugene Zhulenev       auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
236*c30ab6c2SEugene Zhulenev                                     Location executeLoc,
237*c30ab6c2SEugene Zhulenev                                     ValueRange executeArgs) {
238*c30ab6c2SEugene Zhulenev         executeBuilder.create<scf::ForOp>(executeLoc, blockLowerBounds[0],
239*c30ab6c2SEugene Zhulenev                                           blockUpperBounds[0], op.step()[0],
240*c30ab6c2SEugene Zhulenev                                           ValueRange(), workLoopBuilder(0));
241*c30ab6c2SEugene Zhulenev         executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
242*c30ab6c2SEugene Zhulenev       };
243*c30ab6c2SEugene Zhulenev 
244*c30ab6c2SEugene Zhulenev       auto execute = b.create<ExecuteOp>(
245*c30ab6c2SEugene Zhulenev           loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(),
246*c30ab6c2SEugene Zhulenev           /*operands=*/ValueRange(), executeBodyBuilder);
247*c30ab6c2SEugene Zhulenev       auto rankType = IndexType::get(ctx);
248*c30ab6c2SEugene Zhulenev       b.create<AddToGroupOp>(loc, rankType, execute.token(), group.result());
249*c30ab6c2SEugene Zhulenev       b.create<scf::YieldOp>(loc);
250*c30ab6c2SEugene Zhulenev     };
251*c30ab6c2SEugene Zhulenev   };
252*c30ab6c2SEugene Zhulenev 
253*c30ab6c2SEugene Zhulenev   // Start building a loop nest from the first induction variable.
254*c30ab6c2SEugene Zhulenev   rewriter.create<scf::ForOp>(loc, c0, numBlocks[0], c1, ValueRange(),
255*c30ab6c2SEugene Zhulenev                               asyncLoopBuilder(0));
256*c30ab6c2SEugene Zhulenev 
257*c30ab6c2SEugene Zhulenev   // Wait for the completion of all subtasks.
258*c30ab6c2SEugene Zhulenev   rewriter.create<AwaitAllOp>(loc, group.result());
259*c30ab6c2SEugene Zhulenev 
260*c30ab6c2SEugene Zhulenev   // Erase the original parallel operation.
261*c30ab6c2SEugene Zhulenev   rewriter.eraseOp(op);
262*c30ab6c2SEugene Zhulenev 
263*c30ab6c2SEugene Zhulenev   return success();
264*c30ab6c2SEugene Zhulenev }
265*c30ab6c2SEugene Zhulenev 
266*c30ab6c2SEugene Zhulenev void AsyncParallelForPass::runOnFunction() {
267*c30ab6c2SEugene Zhulenev   MLIRContext *ctx = &getContext();
268*c30ab6c2SEugene Zhulenev 
269*c30ab6c2SEugene Zhulenev   OwningRewritePatternList patterns;
270*c30ab6c2SEugene Zhulenev   patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
271*c30ab6c2SEugene Zhulenev 
272*c30ab6c2SEugene Zhulenev   if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
273*c30ab6c2SEugene Zhulenev     signalPassFailure();
274*c30ab6c2SEugene Zhulenev }
275*c30ab6c2SEugene Zhulenev 
276*c30ab6c2SEugene Zhulenev std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
277*c30ab6c2SEugene Zhulenev   return std::make_unique<AsyncParallelForPass>();
278*c30ab6c2SEugene Zhulenev }
279