1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Async/IR/Async.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 #define DEBUG_TYPE "convert-async-to-llvm"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
// Prefix for the names of functions outlined from `async.execute` op regions
// (used by `outlineExecuteOp` when creating the outlined FuncOp).
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
//===----------------------------------------------------------------------===//

// Symbol names of the Async Runtime C API functions. Their MLIR function types
// are given by the `AsyncAPI` helpers, and `addAsyncRuntimeApiDeclarations`
// declares them in the module when they are not already present.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
48 
49 namespace {
50 // Async Runtime API function types.
51 struct AsyncAPI {
52   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
53     return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
54   }
55 
56   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
57     return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
58   }
59 
60   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
61     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
62   }
63 
64   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
65     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
66   }
67 
68   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
69     return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
70   }
71 
72   static FunctionType executeFunctionType(MLIRContext *ctx) {
73     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
74     auto resume = resumeFunctionType(ctx).getPointerTo();
75     return FunctionType::get({hdl, resume}, {}, ctx);
76   }
77 
78   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
79     auto i64 = IntegerType::get(64, ctx);
80     return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
81                              ctx);
82   }
83 
84   static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
85     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
86     auto resume = resumeFunctionType(ctx).getPointerTo();
87     return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
88   }
89 
90   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
91     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
92     auto resume = resumeFunctionType(ctx).getPointerTo();
93     return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
94   }
95 
96   // Auxiliary coroutine resume intrinsic wrapper.
97   static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
98     auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
99     auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
100     return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
101   }
102 };
103 } // namespace
104 
105 // Adds Async Runtime C API declarations to the module.
106 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
107   auto builder = OpBuilder::atBlockTerminator(module.getBody());
108 
109   MLIRContext *ctx = module.getContext();
110   Location loc = module.getLoc();
111 
112   if (!module.lookupSymbol(kCreateToken))
113     builder.create<FuncOp>(loc, kCreateToken,
114                            AsyncAPI::createTokenFunctionType(ctx));
115 
116   if (!module.lookupSymbol(kCreateGroup))
117     builder.create<FuncOp>(loc, kCreateGroup,
118                            AsyncAPI::createGroupFunctionType(ctx));
119 
120   if (!module.lookupSymbol(kEmplaceToken))
121     builder.create<FuncOp>(loc, kEmplaceToken,
122                            AsyncAPI::emplaceTokenFunctionType(ctx));
123 
124   if (!module.lookupSymbol(kAwaitToken))
125     builder.create<FuncOp>(loc, kAwaitToken,
126                            AsyncAPI::awaitTokenFunctionType(ctx));
127 
128   if (!module.lookupSymbol(kAwaitGroup))
129     builder.create<FuncOp>(loc, kAwaitGroup,
130                            AsyncAPI::awaitGroupFunctionType(ctx));
131 
132   if (!module.lookupSymbol(kExecute))
133     builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx));
134 
135   if (!module.lookupSymbol(kAddTokenToGroup))
136     builder.create<FuncOp>(loc, kAddTokenToGroup,
137                            AsyncAPI::addTokenToGroupFunctionType(ctx));
138 
139   if (!module.lookupSymbol(kAwaitAndExecute))
140     builder.create<FuncOp>(loc, kAwaitAndExecute,
141                            AsyncAPI::awaitAndExecuteFunctionType(ctx));
142 
143   if (!module.lookupSymbol(kAwaitAllAndExecute))
144     builder.create<FuncOp>(loc, kAwaitAllAndExecute,
145                            AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // LLVM coroutines intrinsics declarations.
150 //===----------------------------------------------------------------------===//
151 
152 static constexpr const char *kCoroId = "llvm.coro.id";
153 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
154 static constexpr const char *kCoroBegin = "llvm.coro.begin";
155 static constexpr const char *kCoroSave = "llvm.coro.save";
156 static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
157 static constexpr const char *kCoroEnd = "llvm.coro.end";
158 static constexpr const char *kCoroFree = "llvm.coro.free";
159 static constexpr const char *kCoroResume = "llvm.coro.resume";
160 
161 /// Adds coroutine intrinsics declarations to the module.
162 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
163   using namespace mlir::LLVM;
164 
165   MLIRContext *ctx = module.getContext();
166   Location loc = module.getLoc();
167 
168   OpBuilder builder(module.getBody()->getTerminator());
169 
170   auto token = LLVMTokenType::get(ctx);
171   auto voidTy = LLVMType::getVoidTy(ctx);
172 
173   auto i8 = LLVMType::getInt8Ty(ctx);
174   auto i1 = LLVMType::getInt1Ty(ctx);
175   auto i32 = LLVMType::getInt32Ty(ctx);
176   auto i64 = LLVMType::getInt64Ty(ctx);
177   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
178 
179   if (!module.lookupSymbol(kCoroId))
180     builder.create<LLVMFuncOp>(
181         loc, kCoroId,
182         LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false));
183 
184   if (!module.lookupSymbol(kCoroSizeI64))
185     builder.create<LLVMFuncOp>(loc, kCoroSizeI64,
186                                LLVMType::getFunctionTy(i64, false));
187 
188   if (!module.lookupSymbol(kCoroBegin))
189     builder.create<LLVMFuncOp>(
190         loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
191 
192   if (!module.lookupSymbol(kCoroSave))
193     builder.create<LLVMFuncOp>(loc, kCoroSave,
194                                LLVMType::getFunctionTy(token, i8Ptr, false));
195 
196   if (!module.lookupSymbol(kCoroSuspend))
197     builder.create<LLVMFuncOp>(loc, kCoroSuspend,
198                                LLVMType::getFunctionTy(i8, {token, i1}, false));
199 
200   if (!module.lookupSymbol(kCoroEnd))
201     builder.create<LLVMFuncOp>(loc, kCoroEnd,
202                                LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false));
203 
204   if (!module.lookupSymbol(kCoroFree))
205     builder.create<LLVMFuncOp>(
206         loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
207 
208   if (!module.lookupSymbol(kCoroResume))
209     builder.create<LLVMFuncOp>(loc, kCoroResume,
210                                LLVMType::getFunctionTy(voidTy, i8Ptr, false));
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // Add malloc/free declarations to the module.
215 //===----------------------------------------------------------------------===//
216 
217 static constexpr const char *kMalloc = "malloc";
218 static constexpr const char *kFree = "free";
219 
220 /// Adds malloc/free declarations to the module.
221 static void addCRuntimeDeclarations(ModuleOp module) {
222   using namespace mlir::LLVM;
223 
224   MLIRContext *ctx = module.getContext();
225   Location loc = module.getLoc();
226 
227   OpBuilder builder(module.getBody()->getTerminator());
228 
229   auto voidTy = LLVMType::getVoidTy(ctx);
230   auto i64 = LLVMType::getInt64Ty(ctx);
231   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
232 
233   if (!module.lookupSymbol(kMalloc))
234     builder.create<LLVM::LLVMFuncOp>(
235         loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false));
236 
237   if (!module.lookupSymbol(kFree))
238     builder.create<LLVM::LLVMFuncOp>(
239         loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false));
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Coroutine resume function wrapper.
244 //===----------------------------------------------------------------------===//
245 
246 static constexpr const char *kResume = "__resume";
247 
248 // A function that takes a coroutine handle and calls a `llvm.coro.resume`
249 // intrinsics. We need this function to be able to pass it to the async
250 // runtime execute API.
251 static void addResumeFunction(ModuleOp module) {
252   MLIRContext *ctx = module.getContext();
253 
254   OpBuilder moduleBuilder(module.getBody()->getTerminator());
255   Location loc = module.getLoc();
256 
257   if (module.lookupSymbol(kResume))
258     return;
259 
260   auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
261   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
262 
263   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
264       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
265   SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private);
266 
267   auto *block = resumeOp.addEntryBlock();
268   OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
269 
270   blockBuilder.create<LLVM::CallOp>(loc, Type(),
271                                     blockBuilder.getSymbolRefAttr(kCoroResume),
272                                     resumeOp.getArgument(0));
273 
274   blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // async.execute op outlining to the coroutine functions.
279 //===----------------------------------------------------------------------===//
280 
281 // Function targeted for coroutine transformation has two additional blocks at
282 // the end: coroutine cleanup and coroutine suspension.
283 //
284 // async.await op lowering additionaly creates a resume block for each
285 // operation to enable non-blocking waiting via coroutine suspension.
286 namespace {
287 struct CoroMachinery {
288   Value asyncToken;
289   Value coroHandle;
290   Block *cleanup;
291   Block *suspend;
292 };
293 } // namespace
294 
// Builds a coroutine template compatible with LLVM coroutines lowering.
//
//  - `entry` block sets up the coroutine.
//  - `cleanup` block cleans up the coroutine state.
//  - `suspend` block calls @llvm.coro.end(); the code after that call defines
//    what value will be returned to the initial caller of a coroutine.
//    Everything before the @llvm.coro.end() will be executed at every
//    suspension point.
//
// Coroutine structure (only the important bits):
//
//   func @async_execute_fn(<function-arguments>) -> !async.token {
//     ^entryBlock(<function-arguments>):
//       %token = <async token> : !async.token // create async runtime token
//       %id = llvm.call @llvm.coro.id(...)    // create a coroutine id
//       %hdl = llvm.call @llvm.coro.begin(...) // create a coroutine handle
//       br ^cleanup
//
//     ^cleanup:
//       llvm.call @llvm.coro.free(...)        // delete coroutine state
//       br ^suspend
//
//     ^suspend:
//       llvm.call @llvm.coro.end(...)         // marks the end of a coroutine
//       return %token : !async.token
//   }
//
// The actual code for the async.execute operation body region will be inserted
// before the entry block terminator.
//
// The coroutine intrinsics, `malloc`/`free` and the async runtime functions are
// referenced by symbol name only; their declarations are added to the module
// separately.
//
// Returns the created async token, the coroutine handle, and the cleanup and
// suspend blocks.
static CoroMachinery setupCoroMachinery(FuncOp func) {
  assert(func.getBody().empty() && "Function must have empty body");

  MLIRContext *ctx = func.getContext();

  auto token = LLVM::LLVMTokenType::get(ctx);
  auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
  auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
  auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
  auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);

  Block *entryBlock = func.addEntryBlock();
  Location loc = func.getBody().getLoc();

  OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);

  // ------------------------------------------------------------------------ //
  // Allocate async tokens/values that we will return from a ramp function.
  // ------------------------------------------------------------------------ //
  auto createToken =
      builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));

  // ------------------------------------------------------------------------ //
  // Initialize coroutine: allocate frame, get coroutine handle.
  // ------------------------------------------------------------------------ //

  // Constants for initializing coroutine frame.
  auto constZero =
      builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
  auto constFalse =
      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
  auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);

  // Get coroutine id: @llvm.coro.id
  auto coroId = builder.create<LLVM::CallOp>(
      loc, token, builder.getSymbolRefAttr(kCoroId),
      ValueRange({constZero, nullPtr, nullPtr, nullPtr}));

  // Get coroutine frame size: @llvm.coro.size.i64
  auto coroSize = builder.create<LLVM::CallOp>(
      loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());

  // Allocate memory for coroutine frame (`malloc` of the size reported above).
  auto coroAlloc = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
      ValueRange(coroSize.getResult(0)));

  // Begin a coroutine: @llvm.coro.begin
  auto coroHdl = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
      ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));

  // Both blocks are appended after the entry block, in this order.
  Block *cleanupBlock = func.addBlock();
  Block *suspendBlock = func.addBlock();

  // ------------------------------------------------------------------------ //
  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(cleanupBlock);

  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
  auto coroMem = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
      ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));

  // Free the memory.
  builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree),
                               ValueRange(coroMem.getResult(0)));
  // Branch into the suspend block.
  builder.create<BranchOp>(loc, suspendBlock);

  // ------------------------------------------------------------------------ //
  // Coroutine suspend block: mark the end of a coroutine and return allocated
  // async token.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(suspendBlock);

  // Mark the end of a coroutine: @llvm.coro.end.
  builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
                               ValueRange({coroHdl.getResult(0), constFalse}));

  // Return created `async.token` from the suspend block. This will be the
  // return value of a coroutine ramp function.
  builder.create<ReturnOp>(loc, createToken.getResult(0));

  // Branch from the entry block to the cleanup block to create a valid CFG.
  builder.setInsertionPointToEnd(entryBlock);

  builder.create<BranchOp>(loc, cleanupBlock);

  // `async.await` op lowering will create resume blocks for async
  // continuations, and will conditionally branch to cleanup or suspend blocks.

  return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
          suspendBlock};
}
420 
// Adds a suspension point before the `op`, and moves `op` and all operations
// after it into the resume block. Returns a pointer to the resume block.
// `op` itself is not erased; callers decide what to do with it.
//
// `coro` provides the cleanup and suspend blocks to branch to.
// `coroState` must be a value returned from the call to @llvm.coro.save(...)
// intrinsic (saved coroutine state).
//
// Before:
//
//   ^bb0:
//     "opBefore"(...)
//     "op"(...)
//   ^cleanup: ...
//   ^suspend: ...
//
// After:
//
//   ^bb0:
//     "opBefore"(...)
//     %suspend = llvm.call @llvm.coro.suspend(...)
//     %isNegOne = llvm.icmp "eq" %suspend, -1
//     llvm.cond_br %isNegOne, ^suspend, ^resumeOrCleanup
//   ^resumeOrCleanup:
//     %isZero = llvm.icmp "eq" %suspend, 0
//     llvm.cond_br %isZero, ^resume, ^cleanup
//   ^resume:
//     "op"(...)
//   ^cleanup: ...
//   ^suspend: ...
//
// i.e. the @llvm.coro.suspend result is dispatched as
// [-1: ^suspend, 0: ^resume, otherwise: ^cleanup] via two conditional
// branches.
static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
                                 Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
  auto i8 = LLVM::LLVMType::getInt8Ty(ctx);

  Location loc = op->getLoc();
  Block *splitBlock = op->getBlock();

  // Split the block before `op`, newly added block is the resume block.
  Block *resume = splitBlock->splitBlock(op);

  // Add a coroutine suspension in place of original `op` in the split block.
  OpBuilder builder = OpBuilder::atBlockEnd(splitBlock);

  auto constFalse =
      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));

  // Suspend a coroutine: @llvm.coro.suspend
  auto coroSuspend = builder.create<LLVM::CallOp>(
      loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
      ValueRange({coroState, constFalse}));

  // After a suspension point decide if we should branch into resume, cleanup
  // or suspend block of the coroutine (see @llvm.coro.suspend return code
  // documentation).
  auto constZero =
      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
  auto constNegOne =
      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));

  // Intermediate block inserted right before `resume`; it holds the second
  // comparison.
  Block *resumeOrCleanup = builder.createBlock(resume);

  // Suspend the coroutine (result == -1) ...
  builder.setInsertionPointToEnd(splitBlock);
  auto isNegOne = builder.create<LLVM::ICmpOp>(
      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
  builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
                                 /*falseDest=*/resumeOrCleanup);

  // ... or resume (result == 0) or cleanup (any other result) the coroutine.
  builder.setInsertionPointToStart(resumeOrCleanup);
  auto isZero = builder.create<LLVM::ICmpOp>(
      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
  builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
                                 /*falseDest=*/coro.cleanup);

  return resume;
}
495 
496 // Outline the body region attached to the `async.execute` op into a standalone
497 // function.
498 static std::pair<FuncOp, CoroMachinery>
499 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
500   ModuleOp module = execute.getParentOfType<ModuleOp>();
501 
502   MLIRContext *ctx = module.getContext();
503   Location loc = execute.getLoc();
504 
505   OpBuilder moduleBuilder(module.getBody()->getTerminator());
506 
507   // Collect all outlined function inputs.
508   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
509                                               execute.dependencies().end());
510   getUsedValuesDefinedAbove(execute.body(), functionInputs);
511 
512   // Collect types for the outlined function inputs and outputs.
513   auto typesRange = llvm::map_range(
514       functionInputs, [](Value value) { return value.getType(); });
515   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
516   auto outputTypes = execute.getResultTypes();
517 
518   auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
519   auto funcAttrs = ArrayRef<NamedAttribute>();
520 
521   // TODO: Derive outlined function name from the parent FuncOp (support
522   // multiple nested async.execute operations).
523   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
524   symbolTable.insert(func, moduleBuilder.getInsertionPoint());
525 
526   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
527 
528   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
529   // blocks, adding llvm.coro instrinsics and setting up control flow.
530   CoroMachinery coro = setupCoroMachinery(func);
531 
532   // Suspend async function at the end of an entry block, and resume it using
533   // Async execute API (execution will be resumed in a thread managed by the
534   // async runtime).
535   Block *entryBlock = &func.getBlocks().front();
536   OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
537 
538   // A pointer to coroutine resume intrinsic wrapper.
539   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
540   auto resumePtr = builder.create<LLVM::AddressOfOp>(
541       loc, resumeFnTy.getPointerTo(), kResume);
542 
543   // Save the coroutine state: @llvm.coro.save
544   auto coroSave = builder.create<LLVM::CallOp>(
545       loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
546       ValueRange({coro.coroHandle}));
547 
548   // Call async runtime API to execute a coroutine in the managed thread.
549   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
550   builder.create<CallOp>(loc, Type(), kExecute, executeArgs);
551 
552   // Split the entry block before the terminator.
553   Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
554                                      entryBlock->getTerminator());
555 
556   // Await on all dependencies before starting to execute the body region.
557   builder.setInsertionPointToStart(resume);
558   for (size_t i = 0; i < execute.dependencies().size(); ++i)
559     builder.create<AwaitOp>(loc, func.getArgument(i));
560 
561   // Map from function inputs defined above the execute op to the function
562   // arguments.
563   BlockAndValueMapping valueMapping;
564   valueMapping.map(functionInputs, func.getArguments());
565 
566   // Clone all operations from the execute operation body into the outlined
567   // function body, and replace all `async.yield` operations with a call
568   // to async runtime to emplace the result token.
569   for (Operation &op : execute.body().getOps()) {
570     if (isa<async::YieldOp>(op)) {
571       builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
572       continue;
573     }
574     builder.clone(op, valueMapping);
575   }
576 
577   // Replace the original `async.execute` with a call to outlined function.
578   OpBuilder callBuilder(execute);
579   auto callOutlinedFunc =
580       callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
581                                  functionInputs.getArrayRef());
582   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
583   execute.erase();
584 
585   return {func, coro};
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // Convert Async dialect types to LLVM types.
590 //===----------------------------------------------------------------------===//
591 
592 namespace {
593 class AsyncRuntimeTypeConverter : public TypeConverter {
594 public:
595   AsyncRuntimeTypeConverter() { addConversion(convertType); }
596 
597   static Type convertType(Type type) {
598     MLIRContext *ctx = type.getContext();
599     // Convert async tokens and groups to opaque pointers.
600     if (type.isa<TokenType, GroupType>())
601       return LLVM::LLVMType::getInt8PtrTy(ctx);
602     return type;
603   }
604 };
605 } // namespace
606 
607 //===----------------------------------------------------------------------===//
608 // Convert types for all call operations to lowered async types.
609 //===----------------------------------------------------------------------===//
610 
611 namespace {
612 class CallOpOpConversion : public ConversionPattern {
613 public:
614   explicit CallOpOpConversion(MLIRContext *ctx)
615       : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
616 
617   LogicalResult
618   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
619                   ConversionPatternRewriter &rewriter) const override {
620     AsyncRuntimeTypeConverter converter;
621 
622     SmallVector<Type, 5> resultTypes;
623     converter.convertTypes(op->getResultTypes(), resultTypes);
624 
625     CallOp call = cast<CallOp>(op);
626     rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
627                                         call.getOperands());
628 
629     return success();
630   }
631 };
632 } // namespace
633 
634 //===----------------------------------------------------------------------===//
635 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
636 //===----------------------------------------------------------------------===//
637 
638 namespace {
639 class CreateGroupOpLowering : public ConversionPattern {
640 public:
641   explicit CreateGroupOpLowering(MLIRContext *ctx)
642       : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
643 
644   LogicalResult
645   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
646                   ConversionPatternRewriter &rewriter) const override {
647     auto retTy = GroupType::get(op->getContext());
648     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
649     return success();
650   }
651 };
652 } // namespace
653 
654 //===----------------------------------------------------------------------===//
655 // async.add_to_group op lowering to runtime function call.
656 //===----------------------------------------------------------------------===//
657 
658 namespace {
659 class AddToGroupOpLowering : public ConversionPattern {
660 public:
661   explicit AddToGroupOpLowering(MLIRContext *ctx)
662       : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
663 
664   LogicalResult
665   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
666                   ConversionPatternRewriter &rewriter) const override {
667     // Currently we can only add tokens to the group.
668     auto addToGroup = cast<AddToGroupOp>(op);
669     if (!addToGroup.operand().getType().isa<TokenType>())
670       return failure();
671 
672     auto i64 = IntegerType::get(64, op->getContext());
673     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
674     return success();
675   }
676 };
677 } // namespace
678 
679 //===----------------------------------------------------------------------===//
680 // async.await and async.await_all op lowerings to the corresponding async
681 // runtime function calls.
682 //===----------------------------------------------------------------------===//
683 
684 namespace {
685 
686 template <typename AwaitType, typename AwaitableType>
687 class AwaitOpLoweringBase : public ConversionPattern {
688 protected:
689   explicit AwaitOpLoweringBase(
690       MLIRContext *ctx,
691       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
692       StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
693       : ConversionPattern(AwaitType::getOperationName(), 1, ctx),
694         outlinedFunctions(outlinedFunctions),
695         blockingAwaitFuncName(blockingAwaitFuncName),
696         coroAwaitFuncName(coroAwaitFuncName) {}
697 
698 public:
699   LogicalResult
700   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
701                   ConversionPatternRewriter &rewriter) const override {
702     // We can only await on one the `AwaitableType` (for `await` it can be
703     // only a `token`, for `await_all` it is a `group`).
704     auto await = cast<AwaitType>(op);
705     if (!await.operand().getType().template isa<AwaitableType>())
706       return failure();
707 
708     // Check if await operation is inside the outlined coroutine function.
709     auto func = await.template getParentOfType<FuncOp>();
710     auto outlined = outlinedFunctions.find(func);
711     const bool isInCoroutine = outlined != outlinedFunctions.end();
712 
713     Location loc = op->getLoc();
714 
715     // Inside regular function we convert await operation to the blocking
716     // async API await function call.
717     if (!isInCoroutine)
718       rewriter.create<CallOp>(loc, Type(), blockingAwaitFuncName,
719                               ValueRange(op->getOperand(0)));
720 
721     // Inside the coroutine we convert await operation into coroutine suspension
722     // point, and resume execution asynchronously.
723     if (isInCoroutine) {
724       const CoroMachinery &coro = outlined->getSecond();
725 
726       OpBuilder builder(op);
727       MLIRContext *ctx = op->getContext();
728 
729       // A pointer to coroutine resume intrinsic wrapper.
730       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
731       auto resumePtr = builder.create<LLVM::AddressOfOp>(
732           loc, resumeFnTy.getPointerTo(), kResume);
733 
734       // Save the coroutine state: @llvm.coro.save
735       auto coroSave = builder.create<LLVM::CallOp>(
736           loc, LLVM::LLVMTokenType::get(ctx),
737           builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
738 
739       // Call async runtime API to resume a coroutine in the managed thread when
740       // the async await argument becomes ready.
741       SmallVector<Value, 3> awaitAndExecuteArgs = {
742           await.getOperand(), coro.coroHandle, resumePtr.res()};
743       builder.create<CallOp>(loc, Type(), coroAwaitFuncName,
744                              awaitAndExecuteArgs);
745 
746       // Split the entry block before the await operation.
747       addSuspensionPoint(coro, coroSave.getResult(0), op);
748     }
749 
750     // Original operation was replaced by function call or suspension point.
751     rewriter.eraseOp(op);
752 
753     return success();
754   }
755 
756 private:
757   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
758   StringRef blockingAwaitFuncName;
759   StringRef coroAwaitFuncName;
760 };
761 
762 // Lowering for `async.await` operation (only token operands are supported).
763 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
764   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
765 
766 public:
767   explicit AwaitOpLowering(
768       MLIRContext *ctx,
769       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
770       : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
771 };
772 
773 // Lowering for `async.await_all` operation.
774 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
775   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
776 
777 public:
778   explicit AwaitAllOpLowering(
779       MLIRContext *ctx,
780       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
781       : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
782 };
783 
784 } // namespace
785 
786 //===----------------------------------------------------------------------===//
787 
788 namespace {
789 struct ConvertAsyncToLLVMPass
790     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
791   void runOnOperation() override;
792 };
793 
794 void ConvertAsyncToLLVMPass::runOnOperation() {
795   ModuleOp module = getOperation();
796   SymbolTable symbolTable(module);
797 
798   // Outline all `async.execute` body regions into async functions (coroutines).
799   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
800 
801   WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
802     // We currently do not support execute operations that have async value
803     // operands or produce async results.
804     if (!execute.operands().empty() || !execute.results().empty()) {
805       execute.emitOpError("can't outline async.execute op with async value "
806                           "operands or returned async results");
807       return WalkResult::interrupt();
808     }
809 
810     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
811 
812     return WalkResult::advance();
813   });
814 
815   // Failed to outline all async execute operations.
816   if (outlineResult.wasInterrupted()) {
817     signalPassFailure();
818     return;
819   }
820 
821   LLVM_DEBUG({
822     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
823                  << " async functions\n";
824   });
825 
826   // Add declarations for all functions required by the coroutines lowering.
827   addResumeFunction(module);
828   addAsyncRuntimeApiDeclarations(module);
829   addCoroutineIntrinsicsDeclarations(module);
830   addCRuntimeDeclarations(module);
831 
832   MLIRContext *ctx = &getContext();
833 
834   // Convert async dialect types and operations to LLVM dialect.
835   AsyncRuntimeTypeConverter converter;
836   OwningRewritePatternList patterns;
837 
838   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
839   patterns.insert<CallOpOpConversion>(ctx);
840   patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
841   patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
842 
843   ConversionTarget target(*ctx);
844   target.addLegalDialect<LLVM::LLVMDialect>();
845   target.addIllegalDialect<AsyncDialect>();
846   target.addDynamicallyLegalOp<FuncOp>(
847       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
848   target.addDynamicallyLegalOp<CallOp>(
849       [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
850 
851   if (failed(applyPartialConversion(module, target, std::move(patterns))))
852     signalPassFailure();
853 }
854 } // namespace
855 
856 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
857   return std::make_unique<ConvertAsyncToLLVMPass>();
858 }
859