1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Async/IR/Async.h"
16 #include "mlir/Dialect/Async/Passes.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/SetVector.h"
24 
25 using namespace mlir;
26 using namespace mlir::async;
27 
#define DEBUG_TYPE "async-to-async-runtime"

/// Symbol name used for every function outlined from an `async.execute` op
/// region (the outlined function is created with exactly this name).
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
31 
32 namespace {
33 
34 class AsyncToAsyncRuntimePass
35     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
36 public:
37   AsyncToAsyncRuntimePass() = default;
38   void runOnOperation() override;
39 };
40 
41 } // namespace
42 
43 //===----------------------------------------------------------------------===//
44 // async.execute op outlining to the coroutine functions.
45 //===----------------------------------------------------------------------===//
46 
47 /// Function targeted for coroutine transformation has two additional blocks at
48 /// the end: coroutine cleanup and coroutine suspension.
49 ///
50 /// async.await op lowering additionaly creates a resume block for each
51 /// operation to enable non-blocking waiting via coroutine suspension.
52 namespace {
53 struct CoroMachinery {
54   // Async execute region returns a completion token, and an async value for
55   // each yielded value.
56   //
57   //   %token, %result = async.execute -> !async.value<T> {
58   //     %0 = constant ... : T
59   //     async.yield %0 : T
60   //   }
61   Value asyncToken; // token representing completion of the async region
62   llvm::SmallVector<Value, 4> returnValues; // returned async values
63 
64   Value coroHandle; // coroutine handle (!async.coro.handle value)
65   Block *cleanup;   // coroutine cleanup block
66   Block *suspend;   // coroutine suspension block
67 };
68 } // namespace
69 
70 /// Builds an coroutine template compatible with LLVM coroutines switched-resume
71 /// lowering using `async.runtime.*` and `async.coro.*` operations.
72 ///
73 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
74 ///
75 ///  - `entry` block sets up the coroutine.
76 ///  - `cleanup` block cleans up the coroutine state.
77 ///  - `suspend block after the @llvm.coro.end() defines what value will be
78 ///    returned to the initial caller of a coroutine. Everything before the
79 ///    @llvm.coro.end() will be executed at every suspension point.
80 ///
81 /// Coroutine structure (only the important bits):
82 ///
83 ///   func @async_execute_fn(<function-arguments>)
84 ///        -> (!async.token, !async.value<T>)
85 ///   {
86 ///     ^entry(<function-arguments>):
87 ///       %token = <async token> : !async.token    // create async runtime token
88 ///       %value = <async value> : !async.value<T> // create async value
89 ///       %id = async.coro.id                      // create a coroutine id
90 ///       %hdl = async.coro.begin %id              // create a coroutine handle
91 ///       br ^cleanup
92 ///
93 ///     ^cleanup:
94 ///       async.coro.free %hdl // delete the coroutine state
95 ///       br ^suspend
96 ///
97 ///     ^suspend:
98 ///       async.coro.end %hdl // marks the end of a coroutine
99 ///       return %token, %value : !async.token, !async.value<T>
100 ///   }
101 ///
102 /// The actual code for the async.execute operation body region will be inserted
103 /// before the entry block terminator.
104 ///
105 ///
106 static CoroMachinery setupCoroMachinery(FuncOp func) {
107   assert(func.getBody().empty() && "Function must have empty body");
108 
109   MLIRContext *ctx = func.getContext();
110   Block *entryBlock = func.addEntryBlock();
111 
112   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
113 
114   // ------------------------------------------------------------------------ //
115   // Allocate async token/values that we will return from a ramp function.
116   // ------------------------------------------------------------------------ //
117   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
118 
119   llvm::SmallVector<Value, 4> retValues;
120   for (auto resType : func.getCallableResults().drop_front())
121     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
122 
123   // ------------------------------------------------------------------------ //
124   // Initialize coroutine: get coroutine id and coroutine handle.
125   // ------------------------------------------------------------------------ //
126   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
127   auto coroHdlOp =
128       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
129 
130   Block *cleanupBlock = func.addBlock();
131   Block *suspendBlock = func.addBlock();
132 
133   // ------------------------------------------------------------------------ //
134   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
135   // ------------------------------------------------------------------------ //
136   builder.setInsertionPointToStart(cleanupBlock);
137   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
138 
139   // Branch into the suspend block.
140   builder.create<BranchOp>(suspendBlock);
141 
142   // ------------------------------------------------------------------------ //
143   // Coroutine suspend block: mark the end of a coroutine and return allocated
144   // async token.
145   // ------------------------------------------------------------------------ //
146   builder.setInsertionPointToStart(suspendBlock);
147 
148   // Mark the end of a coroutine: async.coro.end
149   builder.create<CoroEndOp>(coroHdlOp.handle());
150 
151   // Return created `async.token` and `async.values` from the suspend block.
152   // This will be the return value of a coroutine ramp function.
153   SmallVector<Value, 4> ret{retToken};
154   ret.insert(ret.end(), retValues.begin(), retValues.end());
155   builder.create<ReturnOp>(ret);
156 
157   // Branch from the entry block to the cleanup block to create a valid CFG.
158   builder.setInsertionPointToEnd(entryBlock);
159   builder.create<BranchOp>(cleanupBlock);
160 
161   // `async.await` op lowering will create resume blocks for async
162   // continuations, and will conditionally branch to cleanup or suspend blocks.
163 
164   CoroMachinery machinery;
165   machinery.asyncToken = retToken;
166   machinery.returnValues = retValues;
167   machinery.coroHandle = coroHdlOp.handle();
168   machinery.cleanup = cleanupBlock;
169   machinery.suspend = suspendBlock;
170   return machinery;
171 }
172 
173 /// Outline the body region attached to the `async.execute` op into a standalone
174 /// function.
175 ///
176 /// Note that this is not reversible transformation.
177 static std::pair<FuncOp, CoroMachinery>
178 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
179   ModuleOp module = execute->getParentOfType<ModuleOp>();
180 
181   MLIRContext *ctx = module.getContext();
182   Location loc = execute.getLoc();
183 
184   // Collect all outlined function inputs.
185   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
186                                               execute.dependencies().end());
187   functionInputs.insert(execute.operands().begin(), execute.operands().end());
188   getUsedValuesDefinedAbove(execute.body(), functionInputs);
189 
190   // Collect types for the outlined function inputs and outputs.
191   auto typesRange = llvm::map_range(
192       functionInputs, [](Value value) { return value.getType(); });
193   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
194   auto outputTypes = execute.getResultTypes();
195 
196   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
197   auto funcAttrs = ArrayRef<NamedAttribute>();
198 
199   // TODO: Derive outlined function name from the parent FuncOp (support
200   // multiple nested async.execute operations).
201   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
202   symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
203 
204   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
205 
206   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
207   // blocks, adding async.coro operations and setting up control flow.
208   CoroMachinery coro = setupCoroMachinery(func);
209 
210   // Suspend async function at the end of an entry block, and resume it using
211   // Async resume operation (execution will be resumed in a thread managed by
212   // the async runtime).
213   Block *entryBlock = &func.getBlocks().front();
214   auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
215 
216   // Save the coroutine state: async.coro.save
217   auto coroSaveOp =
218       builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
219 
220   // Pass coroutine to the runtime to be resumed on a runtime managed thread.
221   builder.create<RuntimeResumeOp>(coro.coroHandle);
222 
223   // Split the entry block before the terminator (branch to suspend block).
224   auto *terminatorOp = entryBlock->getTerminator();
225   Block *suspended = terminatorOp->getBlock();
226   Block *resume = suspended->splitBlock(terminatorOp);
227 
228   // Add async.coro.suspend as a suspended block terminator.
229   builder.setInsertionPointToEnd(suspended);
230   builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
231                                 coro.cleanup);
232 
233   size_t numDependencies = execute.dependencies().size();
234   size_t numOperands = execute.operands().size();
235 
236   // Await on all dependencies before starting to execute the body region.
237   builder.setInsertionPointToStart(resume);
238   for (size_t i = 0; i < numDependencies; ++i)
239     builder.create<AwaitOp>(func.getArgument(i));
240 
241   // Await on all async value operands and unwrap the payload.
242   SmallVector<Value, 4> unwrappedOperands(numOperands);
243   for (size_t i = 0; i < numOperands; ++i) {
244     Value operand = func.getArgument(numDependencies + i);
245     unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
246   }
247 
248   // Map from function inputs defined above the execute op to the function
249   // arguments.
250   BlockAndValueMapping valueMapping;
251   valueMapping.map(functionInputs, func.getArguments());
252   valueMapping.map(execute.body().getArguments(), unwrappedOperands);
253 
254   // Clone all operations from the execute operation body into the outlined
255   // function body.
256   for (Operation &op : execute.body().getOps())
257     builder.clone(op, valueMapping);
258 
259   // Replace the original `async.execute` with a call to outlined function.
260   ImplicitLocOpBuilder callBuilder(loc, execute);
261   auto callOutlinedFunc = callBuilder.create<CallOp>(
262       func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
263   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
264   execute.erase();
265 
266   return {func, coro};
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // Convert async.create_group operation to async.runtime.create
271 //===----------------------------------------------------------------------===//
272 
273 namespace {
274 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
275 public:
276   using OpConversionPattern::OpConversionPattern;
277 
278   LogicalResult
279   matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
280                   ConversionPatternRewriter &rewriter) const override {
281     rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
282         op, GroupType::get(op->getContext()));
283     return success();
284   }
285 };
286 } // namespace
287 
288 //===----------------------------------------------------------------------===//
289 // Convert async.add_to_group operation to async.runtime.add_to_group.
290 //===----------------------------------------------------------------------===//
291 
292 namespace {
293 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
294 public:
295   using OpConversionPattern::OpConversionPattern;
296 
297   LogicalResult
298   matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
299                   ConversionPatternRewriter &rewriter) const override {
300     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
301         op, rewriter.getIndexType(), operands);
302     return success();
303   }
304 };
305 } // namespace
306 
307 //===----------------------------------------------------------------------===//
308 // Convert async.await and async.await_all operations to the async.runtime.await
309 // or async.runtime.await_and_resume operations.
310 //===----------------------------------------------------------------------===//
311 
312 namespace {
313 template <typename AwaitType, typename AwaitableType>
314 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
315   using AwaitAdaptor = typename AwaitType::Adaptor;
316 
317 public:
318   AwaitOpLoweringBase(
319       MLIRContext *ctx,
320       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
321       : OpConversionPattern<AwaitType>(ctx),
322         outlinedFunctions(outlinedFunctions) {}
323 
324   LogicalResult
325   matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
326                   ConversionPatternRewriter &rewriter) const override {
327     // We can only await on one the `AwaitableType` (for `await` it can be
328     // a `token` or a `value`, for `await_all` it must be a `group`).
329     if (!op.operand().getType().template isa<AwaitableType>())
330       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
331 
332     // Check if await operation is inside the outlined coroutine function.
333     auto func = op->template getParentOfType<FuncOp>();
334     auto outlined = outlinedFunctions.find(func);
335     const bool isInCoroutine = outlined != outlinedFunctions.end();
336 
337     Location loc = op->getLoc();
338     Value operand = AwaitAdaptor(operands).operand();
339 
340     // Inside regular functions we use the blocking wait operation to wait for
341     // the async object (token, value or group) to become available.
342     if (!isInCoroutine)
343       rewriter.create<RuntimeAwaitOp>(loc, operand);
344 
345     // Inside the coroutine we convert await operation into coroutine suspension
346     // point, and resume execution asynchronously.
347     if (isInCoroutine) {
348       const CoroMachinery &coro = outlined->getSecond();
349       Block *suspended = op->getBlock();
350 
351       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
352       MLIRContext *ctx = op->getContext();
353 
354       // Save the coroutine state and resume on a runtime managed thread when
355       // the operand becomes available.
356       auto coroSaveOp =
357           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
358       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
359 
360       // Split the entry block before the await operation.
361       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
362 
363       // Add async.coro.suspend as a suspended block terminator.
364       builder.setInsertionPointToEnd(suspended);
365       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
366                                     coro.cleanup);
367 
368       // Make sure that replacement value will be constructed in resume block.
369       rewriter.setInsertionPointToStart(resume);
370     }
371 
372     // Erase or replace the await operation with the new value.
373     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
374       rewriter.replaceOp(op, replaceWith);
375     else
376       rewriter.eraseOp(op);
377 
378     return success();
379   }
380 
381   virtual Value getReplacementValue(AwaitType op, Value operand,
382                                     ConversionPatternRewriter &rewriter) const {
383     return Value();
384   }
385 
386 private:
387   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
388 };
389 
390 /// Lowering for `async.await` with a token operand.
391 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
392   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
393 
394 public:
395   using Base::Base;
396 };
397 
398 /// Lowering for `async.await` with a value operand.
399 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
400   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
401 
402 public:
403   using Base::Base;
404 
405   Value
406   getReplacementValue(AwaitOp op, Value operand,
407                       ConversionPatternRewriter &rewriter) const override {
408     // Load from the async value storage.
409     auto valueType = operand.getType().cast<ValueType>().getValueType();
410     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
411   }
412 };
413 
414 /// Lowering for `async.await_all` operation.
415 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
416   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
417 
418 public:
419   using Base::Base;
420 };
421 
422 } // namespace
423 
424 //===----------------------------------------------------------------------===//
425 // Convert async.yield operation to async.runtime operations.
426 //===----------------------------------------------------------------------===//
427 
428 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
429 public:
430   YieldOpLowering(
431       MLIRContext *ctx,
432       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
433       : OpConversionPattern<async::YieldOp>(ctx),
434         outlinedFunctions(outlinedFunctions) {}
435 
436   LogicalResult
437   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
438                   ConversionPatternRewriter &rewriter) const override {
439     // Check if yield operation is inside the outlined coroutine function.
440     auto func = op->template getParentOfType<FuncOp>();
441     auto outlined = outlinedFunctions.find(func);
442     if (outlined == outlinedFunctions.end())
443       return rewriter.notifyMatchFailure(
444           op, "operation is not inside the outlined async.execute function");
445 
446     Location loc = op->getLoc();
447     const CoroMachinery &coro = outlined->getSecond();
448 
449     // Store yielded values into the async values storage and switch async
450     // values state to available.
451     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
452       Value yieldValue = std::get<0>(tuple);
453       Value asyncValue = std::get<1>(tuple);
454       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
455       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
456     }
457 
458     // Switch the coroutine completion token to available state.
459     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
460 
461     return success();
462   }
463 
464 private:
465   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
466 };
467 
468 //===----------------------------------------------------------------------===//
469 
470 void AsyncToAsyncRuntimePass::runOnOperation() {
471   ModuleOp module = getOperation();
472   SymbolTable symbolTable(module);
473 
474   // Outline all `async.execute` body regions into async functions (coroutines).
475   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
476 
477   module.walk([&](ExecuteOp execute) {
478     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
479   });
480 
481   LLVM_DEBUG({
482     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
483                  << " functions built from async.execute operations\n";
484   });
485 
486   // Lower async operations to async.runtime operations.
487   MLIRContext *ctx = module->getContext();
488   OwningRewritePatternList asyncPatterns;
489 
490   // Async lowering does not use type converter because it must preserve all
491   // types for async.runtime operations.
492   asyncPatterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
493   asyncPatterns.insert<AwaitTokenOpLowering, AwaitValueOpLowering,
494                        AwaitAllOpLowering, YieldOpLowering>(ctx,
495                                                             outlinedFunctions);
496 
497   // All high level async operations must be lowered to the runtime operations.
498   ConversionTarget runtimeTarget(*ctx);
499   runtimeTarget.addLegalDialect<AsyncDialect>();
500   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
501   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
502 
503   if (failed(applyPartialConversion(module, runtimeTarget,
504                                     std::move(asyncPatterns)))) {
505     signalPassFailure();
506     return;
507   }
508 }
509 
510 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
511   return std::make_unique<AsyncToAsyncRuntimePass>();
512 }
513