1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Async/IR/Async.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 #define DEBUG_TYPE "convert-async-to-llvm"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
// Prefix for functions outlined from `async.execute` op regions. Outlined
// functions are inserted through a SymbolTable, which may uniquify the name.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
//===----------------------------------------------------------------------===//

// Symbol names of the Async Runtime C API functions. They are declared in the
// module by `addAsyncRuntimeApiDeclarations` with the signatures built by
// `AsyncAPI`.

// Reference counting (`async.add_ref` / `async.drop_ref` lowering).
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Creation of async tokens (coroutine setup) and groups (`async.create_group`).
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Emplaces the async token of an outlined `async.execute` body; emitted in
// place of its `async.yield`.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
// Blocking awaits, used when `async.await` / `async.await_all` appear outside
// of an outlined coroutine.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Hands a coroutine handle and resume callback to the runtime to execute.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// `async.add_to_group` lowering (token operands only).
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// Non-blocking awaits inside coroutines: the runtime resumes the coroutine
// once the awaited token/group becomes ready.
static constexpr const char *kAwaitAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
50 
51 namespace {
52 // Async Runtime API function types.
53 struct AsyncAPI {
54   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
55     auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
56     auto count = IntegerType::get(32, ctx);
57     return FunctionType::get({ref, count}, {}, ctx);
58   }
59 
60   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
61     return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
62   }
63 
64   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
65     return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
66   }
67 
68   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
69     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
70   }
71 
72   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
73     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
74   }
75 
76   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
77     return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
78   }
79 
80   static FunctionType executeFunctionType(MLIRContext *ctx) {
81     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
82     auto resume = resumeFunctionType(ctx).getPointerTo();
83     return FunctionType::get({hdl, resume}, {}, ctx);
84   }
85 
86   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
87     auto i64 = IntegerType::get(64, ctx);
88     return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
89                              ctx);
90   }
91 
92   static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
93     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
94     auto resume = resumeFunctionType(ctx).getPointerTo();
95     return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
96   }
97 
98   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
99     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
100     auto resume = resumeFunctionType(ctx).getPointerTo();
101     return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
102   }
103 
104   // Auxiliary coroutine resume intrinsic wrapper.
105   static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
106     auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
107     auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
108     return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
109   }
110 };
111 } // namespace
112 
113 // Adds Async Runtime C API declarations to the module.
114 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
115   auto builder = OpBuilder::atBlockTerminator(module.getBody());
116 
117   auto addFuncDecl = [&](StringRef name, FunctionType type) {
118     if (module.lookupSymbol(name))
119       return;
120     builder.create<FuncOp>(module.getLoc(), name, type).setPrivate();
121   };
122 
123   MLIRContext *ctx = module.getContext();
124   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
125   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
126   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
127   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
128   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
129   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
130   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
131   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
132   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
133   addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
134   addFuncDecl(kAwaitAllAndExecute,
135               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // LLVM coroutines intrinsics declarations.
140 //===----------------------------------------------------------------------===//
141 
// Names of the LLVM coroutine intrinsics used by this lowering. They are
// declared in the module by `addCoroutineIntrinsicsDeclarations`.
static constexpr const char *kCoroId = "llvm.coro.id";
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr const char *kCoroBegin = "llvm.coro.begin";
static constexpr const char *kCoroSave = "llvm.coro.save";
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";
static constexpr const char *kCoroResume = "llvm.coro.resume";
150 
151 /// Adds an LLVM function declaration to a module.
152 static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name,
153                             LLVM::LLVMType ret,
154                             ArrayRef<LLVM::LLVMType> params) {
155   if (module.lookupSymbol(name))
156     return;
157   LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false);
158   builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type);
159 }
160 
161 /// Adds coroutine intrinsics declarations to the module.
162 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
163   using namespace mlir::LLVM;
164 
165   MLIRContext *ctx = module.getContext();
166   OpBuilder builder(module.getBody()->getTerminator());
167 
168   auto token = LLVMTokenType::get(ctx);
169   auto voidTy = LLVMType::getVoidTy(ctx);
170 
171   auto i8 = LLVMType::getInt8Ty(ctx);
172   auto i1 = LLVMType::getInt1Ty(ctx);
173   auto i32 = LLVMType::getInt32Ty(ctx);
174   auto i64 = LLVMType::getInt64Ty(ctx);
175   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
176 
177   addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
178   addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
179   addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
180   addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
181   addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
182   addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
183   addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
184   addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // Add malloc/free declarations to the module.
189 //===----------------------------------------------------------------------===//
190 
191 static constexpr const char *kMalloc = "malloc";
192 static constexpr const char *kFree = "free";
193 
194 /// Adds malloc/free declarations to the module.
195 static void addCRuntimeDeclarations(ModuleOp module) {
196   using namespace mlir::LLVM;
197 
198   MLIRContext *ctx = module.getContext();
199   OpBuilder builder(module.getBody()->getTerminator());
200 
201   auto voidTy = LLVMType::getVoidTy(ctx);
202   auto i64 = LLVMType::getInt64Ty(ctx);
203   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
204 
205   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
206   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // Coroutine resume function wrapper.
211 //===----------------------------------------------------------------------===//
212 
213 static constexpr const char *kResume = "__resume";
214 
215 // A function that takes a coroutine handle and calls a `llvm.coro.resume`
216 // intrinsics. We need this function to be able to pass it to the async
217 // runtime execute API.
218 static void addResumeFunction(ModuleOp module) {
219   MLIRContext *ctx = module.getContext();
220 
221   OpBuilder moduleBuilder(module.getBody()->getTerminator());
222   Location loc = module.getLoc();
223 
224   if (module.lookupSymbol(kResume))
225     return;
226 
227   auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
228   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
229 
230   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
231       loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false));
232   resumeOp.setPrivate();
233 
234   auto *block = resumeOp.addEntryBlock();
235   OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
236 
237   blockBuilder.create<LLVM::CallOp>(loc, TypeRange(),
238                                     blockBuilder.getSymbolRefAttr(kCoroResume),
239                                     resumeOp.getArgument(0));
240 
241   blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // async.execute op outlining to the coroutine functions.
246 //===----------------------------------------------------------------------===//
247 
248 // Function targeted for coroutine transformation has two additional blocks at
249 // the end: coroutine cleanup and coroutine suspension.
250 //
251 // async.await op lowering additionaly creates a resume block for each
252 // operation to enable non-blocking waiting via coroutine suspension.
253 namespace {
254 struct CoroMachinery {
255   Value asyncToken;
256   Value coroHandle;
257   Block *cleanup;
258   Block *suspend;
259 };
260 } // namespace
261 
// Builds a coroutine template compatible with LLVM coroutines lowering.
//
//  - `entry` block sets up the coroutine: creates the async runtime token and
//    allocates the coroutine frame (`malloc` of @llvm.coro.size.i64 bytes).
//  - `cleanup` block cleans up the coroutine state: frees the frame memory
//    returned by @llvm.coro.free and branches to the `suspend` block.
//  - `suspend` block after the @llvm.coro.end() defines what value will be
//    returned to the initial caller of a coroutine. Everything before the
//    @llvm.coro.end() will be executed at every suspension point.
//
// Coroutine structure (only the important bits):
//
//   func @async_execute_fn(<function-arguments>) -> !async.token {
//     ^entryBlock(<function-arguments>):
//       %token = <async token> : !async.token // create async runtime token
//       %id = llvm.call @llvm.coro.id(...)    // get a coroutine id
//       %size = llvm.call @llvm.coro.size.i64()
//       %mem = llvm.call @malloc(%size)       // allocate coroutine frame
//       %hdl = llvm.call @llvm.coro.begin(%id, %mem) // coroutine handle
//       br ^cleanup
//
//     ^cleanup:
//       %ptr = llvm.call @llvm.coro.free(%id, %hdl)
//       llvm.call @free(%ptr)                 // delete coroutine state
//       br ^suspend
//
//     ^suspend:
//       llvm.call @llvm.coro.end(%hdl, false) // marks the end of a coroutine
//       return %token : !async.token
//   }
//
// The actual code for the async.execute operation body region will be inserted
// before the entry block terminator.
//
// `func` must have an empty body. Returns the created async token, the
// coroutine handle, and the cleanup/suspend blocks.
static CoroMachinery setupCoroMachinery(FuncOp func) {
  assert(func.getBody().empty() && "Function must have empty body");

  MLIRContext *ctx = func.getContext();

  auto token = LLVM::LLVMTokenType::get(ctx);
  auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
  auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
  auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
  auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);

  Block *entryBlock = func.addEntryBlock();
  Location loc = func.getBody().getLoc();

  OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);

  // ------------------------------------------------------------------------ //
  // Allocate async tokens/values that we will return from a ramp function.
  // ------------------------------------------------------------------------ //
  auto createToken =
      builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));

  // ------------------------------------------------------------------------ //
  // Initialize coroutine: allocate frame, get coroutine handle.
  // ------------------------------------------------------------------------ //

  // Constants for initializing coroutine frame.
  auto constZero =
      builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
  auto constFalse =
      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
  auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);

  // Get coroutine id: @llvm.coro.id (alignment 0, all pointer arguments null).
  auto coroId = builder.create<LLVM::CallOp>(
      loc, token, builder.getSymbolRefAttr(kCoroId),
      ValueRange({constZero, nullPtr, nullPtr, nullPtr}));

  // Get coroutine frame size: @llvm.coro.size.i64
  auto coroSize = builder.create<LLVM::CallOp>(
      loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());

  // Allocate memory for coroutine frame.
  auto coroAlloc = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
      ValueRange(coroSize.getResult(0)));

  // Begin a coroutine: @llvm.coro.begin returns the coroutine handle.
  auto coroHdl = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
      ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));

  Block *cleanupBlock = func.addBlock();
  Block *suspendBlock = func.addBlock();

  // ------------------------------------------------------------------------ //
  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(cleanupBlock);

  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
  auto coroMem = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
      ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));

  // Free the memory.
  builder.create<LLVM::CallOp>(loc, TypeRange(),
                               builder.getSymbolRefAttr(kFree),
                               ValueRange(coroMem.getResult(0)));
  // Branch into the suspend block.
  builder.create<BranchOp>(loc, suspendBlock);

  // ------------------------------------------------------------------------ //
  // Coroutine suspend block: mark the end of a coroutine and return allocated
  // async token.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(suspendBlock);

  // Mark the end of a coroutine: @llvm.coro.end.
  builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
                               ValueRange({coroHdl.getResult(0), constFalse}));

  // Return created `async.token` from the suspend block. This will be the
  // return value of a coroutine ramp function.
  builder.create<ReturnOp>(loc, createToken.getResult(0));

  // Branch from the entry block to the cleanup block to create a valid CFG.
  builder.setInsertionPointToEnd(entryBlock);

  builder.create<BranchOp>(loc, cleanupBlock);

  // `async.await` op lowering will create resume blocks for async
  // continuations, and will conditionally branch to cleanup or suspend blocks.

  return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
          suspendBlock};
}
388 
// Adds a suspension point before the `op`, and moves `op` and all operations
// after it into the resume block. Returns a pointer to the resume block.
//
// `coroState` must be a value returned from the call to @llvm.coro.save(...)
// intrinsic (saved coroutine state).
//
// Before:
//
//   ^bb0:
//     "opBefore"(...)
//     "op"(...)
//   ^cleanup: ...
//   ^suspend: ...
//
// After:
//
//   ^bb0:
//     "opBefore"(...)
//     %suspend = llvm.call @llvm.coro.suspend(...)
//     %isNegOne = llvm.icmp "eq" %suspend, -1
//     llvm.cond_br %isNegOne, ^suspend, ^resumeOrCleanup
//   ^resumeOrCleanup:
//     %isZero = llvm.icmp "eq" %suspend, 0
//     llvm.cond_br %isZero, ^resume, ^cleanup
//   ^resume:
//     "op"(...)
//   ^cleanup: ...
//   ^suspend: ...
//
// i.e. @llvm.coro.suspend result -1 goes to ^suspend, 0 to ^resume, and any
// other value to ^cleanup.
static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
                                 Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
  auto i8 = LLVM::LLVMType::getInt8Ty(ctx);

  Location loc = op->getLoc();
  Block *splitBlock = op->getBlock();

  // Split the block before `op`, newly added block is the resume block.
  Block *resume = splitBlock->splitBlock(op);

  // Add a coroutine suspension in place of original `op` in the split block.
  OpBuilder builder = OpBuilder::atBlockEnd(splitBlock);

  auto constFalse =
      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));

  // Suspend a coroutine: @llvm.coro.suspend
  auto coroSuspend = builder.create<LLVM::CallOp>(
      loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
      ValueRange({coroState, constFalse}));

  // After a suspension point decide if we should branch into resume, cleanup
  // or suspend block of the coroutine (see @llvm.coro.suspend return code
  // documentation).
  auto constZero =
      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
  auto constNegOne =
      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));

  // Intermediate block placed right before `resume`. Note: `createBlock` also
  // moves the builder's insertion point into the new block.
  Block *resumeOrCleanup = builder.createBlock(resume);

  // Suspend the coroutine ...?
  builder.setInsertionPointToEnd(splitBlock);
  auto isNegOne = builder.create<LLVM::ICmpOp>(
      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
  builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
                                 /*falseDest=*/resumeOrCleanup);

  // ... or resume or cleanup the coroutine?
  builder.setInsertionPointToStart(resumeOrCleanup);
  auto isZero = builder.create<LLVM::ICmpOp>(
      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
  builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
                                 /*falseDest=*/coro.cleanup);

  return resume;
}
463 
464 // Outline the body region attached to the `async.execute` op into a standalone
465 // function.
466 static std::pair<FuncOp, CoroMachinery>
467 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
468   ModuleOp module = execute.getParentOfType<ModuleOp>();
469 
470   MLIRContext *ctx = module.getContext();
471   Location loc = execute.getLoc();
472 
473   OpBuilder moduleBuilder(module.getBody()->getTerminator());
474 
475   // Collect all outlined function inputs.
476   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
477                                               execute.dependencies().end());
478   getUsedValuesDefinedAbove(execute.body(), functionInputs);
479 
480   // Collect types for the outlined function inputs and outputs.
481   auto typesRange = llvm::map_range(
482       functionInputs, [](Value value) { return value.getType(); });
483   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
484   auto outputTypes = execute.getResultTypes();
485 
486   auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
487   auto funcAttrs = ArrayRef<NamedAttribute>();
488 
489   // TODO: Derive outlined function name from the parent FuncOp (support
490   // multiple nested async.execute operations).
491   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
492   symbolTable.insert(func, moduleBuilder.getInsertionPoint());
493 
494   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
495 
496   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
497   // blocks, adding llvm.coro instrinsics and setting up control flow.
498   CoroMachinery coro = setupCoroMachinery(func);
499 
500   // Suspend async function at the end of an entry block, and resume it using
501   // Async execute API (execution will be resumed in a thread managed by the
502   // async runtime).
503   Block *entryBlock = &func.getBlocks().front();
504   OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
505 
506   // A pointer to coroutine resume intrinsic wrapper.
507   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
508   auto resumePtr = builder.create<LLVM::AddressOfOp>(
509       loc, resumeFnTy.getPointerTo(), kResume);
510 
511   // Save the coroutine state: @llvm.coro.save
512   auto coroSave = builder.create<LLVM::CallOp>(
513       loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
514       ValueRange({coro.coroHandle}));
515 
516   // Call async runtime API to execute a coroutine in the managed thread.
517   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
518   builder.create<CallOp>(loc, TypeRange(), kExecute, executeArgs);
519 
520   // Split the entry block before the terminator.
521   Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
522                                      entryBlock->getTerminator());
523 
524   // Await on all dependencies before starting to execute the body region.
525   builder.setInsertionPointToStart(resume);
526   for (size_t i = 0; i < execute.dependencies().size(); ++i)
527     builder.create<AwaitOp>(loc, func.getArgument(i));
528 
529   // Map from function inputs defined above the execute op to the function
530   // arguments.
531   BlockAndValueMapping valueMapping;
532   valueMapping.map(functionInputs, func.getArguments());
533 
534   // Clone all operations from the execute operation body into the outlined
535   // function body, and replace all `async.yield` operations with a call
536   // to async runtime to emplace the result token.
537   for (Operation &op : execute.body().getOps()) {
538     if (isa<async::YieldOp>(op)) {
539       builder.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
540       continue;
541     }
542     builder.clone(op, valueMapping);
543   }
544 
545   // Replace the original `async.execute` with a call to outlined function.
546   OpBuilder callBuilder(execute);
547   auto callOutlinedFunc =
548       callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
549                                  functionInputs.getArrayRef());
550   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
551   execute.erase();
552 
553   return {func, coro};
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // Convert Async dialect types to LLVM types.
558 //===----------------------------------------------------------------------===//
559 
560 namespace {
561 class AsyncRuntimeTypeConverter : public TypeConverter {
562 public:
563   AsyncRuntimeTypeConverter() { addConversion(convertType); }
564 
565   static Type convertType(Type type) {
566     MLIRContext *ctx = type.getContext();
567     // Convert async tokens and groups to opaque pointers.
568     if (type.isa<TokenType, GroupType>())
569       return LLVM::LLVMType::getInt8PtrTy(ctx);
570     return type;
571   }
572 };
573 } // namespace
574 
575 //===----------------------------------------------------------------------===//
576 // Convert types for all call operations to lowered async types.
577 //===----------------------------------------------------------------------===//
578 
579 namespace {
580 class CallOpOpConversion : public ConversionPattern {
581 public:
582   explicit CallOpOpConversion(MLIRContext *ctx)
583       : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
584 
585   LogicalResult
586   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
587                   ConversionPatternRewriter &rewriter) const override {
588     AsyncRuntimeTypeConverter converter;
589 
590     SmallVector<Type, 5> resultTypes;
591     converter.convertTypes(op->getResultTypes(), resultTypes);
592 
593     CallOp call = cast<CallOp>(op);
594     rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
595                                         operands);
596 
597     return success();
598   }
599 };
600 } // namespace
601 
602 //===----------------------------------------------------------------------===//
603 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
604 // to the corresponding API calls).
605 //===----------------------------------------------------------------------===//
606 
607 namespace {
608 
609 template <typename RefCountingOp>
610 class RefCountingOpLowering : public ConversionPattern {
611 public:
612   explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
613       : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
614         apiFunctionName(apiFunctionName) {}
615 
616   LogicalResult
617   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
618                   ConversionPatternRewriter &rewriter) const override {
619     RefCountingOp refCountingOp = cast<RefCountingOp>(op);
620 
621     auto count = rewriter.create<ConstantOp>(
622         op->getLoc(), rewriter.getI32Type(),
623         rewriter.getI32IntegerAttr(refCountingOp.count()));
624 
625     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
626                                         ValueRange({operands[0], count}));
627 
628     return success();
629   }
630 
631 private:
632   StringRef apiFunctionName;
633 };
634 
635 // async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
636 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
637 public:
638   explicit AddRefOpLowering(MLIRContext *ctx)
639       : RefCountingOpLowering(ctx, kAddRef) {}
640 };
641 
642 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
643 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
644 public:
645   explicit DropRefOpLowering(MLIRContext *ctx)
646       : RefCountingOpLowering(ctx, kDropRef) {}
647 };
648 
649 } // namespace
650 
651 //===----------------------------------------------------------------------===//
652 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
653 //===----------------------------------------------------------------------===//
654 
655 namespace {
656 class CreateGroupOpLowering : public ConversionPattern {
657 public:
658   explicit CreateGroupOpLowering(MLIRContext *ctx)
659       : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
660 
661   LogicalResult
662   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
663                   ConversionPatternRewriter &rewriter) const override {
664     auto retTy = GroupType::get(op->getContext());
665     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
666     return success();
667   }
668 };
669 } // namespace
670 
671 //===----------------------------------------------------------------------===//
672 // async.add_to_group op lowering to runtime function call.
673 //===----------------------------------------------------------------------===//
674 
675 namespace {
676 class AddToGroupOpLowering : public ConversionPattern {
677 public:
678   explicit AddToGroupOpLowering(MLIRContext *ctx)
679       : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
680 
681   LogicalResult
682   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
683                   ConversionPatternRewriter &rewriter) const override {
684     // Currently we can only add tokens to the group.
685     auto addToGroup = cast<AddToGroupOp>(op);
686     if (!addToGroup.operand().getType().isa<TokenType>())
687       return failure();
688 
689     auto i64 = IntegerType::get(64, op->getContext());
690     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
691     return success();
692   }
693 };
694 } // namespace
695 
696 //===----------------------------------------------------------------------===//
697 // async.await and async.await_all op lowerings to the corresponding async
698 // runtime function calls.
699 //===----------------------------------------------------------------------===//
700 
701 namespace {
702 
703 template <typename AwaitType, typename AwaitableType>
704 class AwaitOpLoweringBase : public ConversionPattern {
705 protected:
706   explicit AwaitOpLoweringBase(
707       MLIRContext *ctx,
708       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
709       StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
710       : ConversionPattern(AwaitType::getOperationName(), 1, ctx),
711         outlinedFunctions(outlinedFunctions),
712         blockingAwaitFuncName(blockingAwaitFuncName),
713         coroAwaitFuncName(coroAwaitFuncName) {}
714 
715 public:
716   LogicalResult
717   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
718                   ConversionPatternRewriter &rewriter) const override {
719     // We can only await on one the `AwaitableType` (for `await` it can be
720     // only a `token`, for `await_all` it is a `group`).
721     auto await = cast<AwaitType>(op);
722     if (!await.operand().getType().template isa<AwaitableType>())
723       return failure();
724 
725     // Check if await operation is inside the outlined coroutine function.
726     auto func = await.template getParentOfType<FuncOp>();
727     auto outlined = outlinedFunctions.find(func);
728     const bool isInCoroutine = outlined != outlinedFunctions.end();
729 
730     Location loc = op->getLoc();
731 
732     // Inside regular function we convert await operation to the blocking
733     // async API await function call.
734     if (!isInCoroutine)
735       rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName,
736                               ValueRange(operands[0]));
737 
738     // Inside the coroutine we convert await operation into coroutine suspension
739     // point, and resume execution asynchronously.
740     if (isInCoroutine) {
741       const CoroMachinery &coro = outlined->getSecond();
742 
743       OpBuilder builder(op);
744       MLIRContext *ctx = op->getContext();
745 
746       // A pointer to coroutine resume intrinsic wrapper.
747       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
748       auto resumePtr = builder.create<LLVM::AddressOfOp>(
749           loc, resumeFnTy.getPointerTo(), kResume);
750 
751       // Save the coroutine state: @llvm.coro.save
752       auto coroSave = builder.create<LLVM::CallOp>(
753           loc, LLVM::LLVMTokenType::get(ctx),
754           builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
755 
756       // Call async runtime API to resume a coroutine in the managed thread when
757       // the async await argument becomes ready.
758       SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle,
759                                                    resumePtr.res()};
760       builder.create<CallOp>(loc, TypeRange(), coroAwaitFuncName,
761                              awaitAndExecuteArgs);
762 
763       // Split the entry block before the await operation.
764       addSuspensionPoint(coro, coroSave.getResult(0), op);
765     }
766 
767     // Original operation was replaced by function call or suspension point.
768     rewriter.eraseOp(op);
769 
770     return success();
771   }
772 
773 private:
774   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
775   StringRef blockingAwaitFuncName;
776   StringRef coroAwaitFuncName;
777 };
778 
779 // Lowering for `async.await` operation (only token operands are supported).
780 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
781   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
782 
783 public:
784   explicit AwaitOpLowering(
785       MLIRContext *ctx,
786       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
787       : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
788 };
789 
790 // Lowering for `async.await_all` operation.
791 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
792   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
793 
794 public:
795   explicit AwaitAllOpLowering(
796       MLIRContext *ctx,
797       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
798       : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
799 };
800 
801 } // namespace
802 
803 //===----------------------------------------------------------------------===//
804 
805 namespace {
806 struct ConvertAsyncToLLVMPass
807     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
808   void runOnOperation() override;
809 };
810 
811 void ConvertAsyncToLLVMPass::runOnOperation() {
812   ModuleOp module = getOperation();
813   SymbolTable symbolTable(module);
814 
815   // Outline all `async.execute` body regions into async functions (coroutines).
816   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
817 
818   WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
819     // We currently do not support execute operations that have async value
820     // operands or produce async results.
821     if (!execute.operands().empty() || !execute.results().empty()) {
822       execute.emitOpError("can't outline async.execute op with async value "
823                           "operands or returned async results");
824       return WalkResult::interrupt();
825     }
826 
827     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
828 
829     return WalkResult::advance();
830   });
831 
832   // Failed to outline all async execute operations.
833   if (outlineResult.wasInterrupted()) {
834     signalPassFailure();
835     return;
836   }
837 
838   LLVM_DEBUG({
839     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
840                  << " async functions\n";
841   });
842 
843   // Add declarations for all functions required by the coroutines lowering.
844   addResumeFunction(module);
845   addAsyncRuntimeApiDeclarations(module);
846   addCoroutineIntrinsicsDeclarations(module);
847   addCRuntimeDeclarations(module);
848 
849   MLIRContext *ctx = &getContext();
850 
851   // Convert async dialect types and operations to LLVM dialect.
852   AsyncRuntimeTypeConverter converter;
853   OwningRewritePatternList patterns;
854 
855   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
856   patterns.insert<CallOpOpConversion>(ctx);
857   patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
858   patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
859   patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
860 
861   ConversionTarget target(*ctx);
862   target.addLegalOp<ConstantOp>();
863   target.addLegalDialect<LLVM::LLVMDialect>();
864   target.addIllegalDialect<AsyncDialect>();
865   target.addDynamicallyLegalOp<FuncOp>(
866       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
867   target.addDynamicallyLegalOp<CallOp>(
868       [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
869 
870   if (failed(applyPartialConversion(module, target, std::move(patterns))))
871     signalPassFailure();
872 }
873 } // namespace
874 
875 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
876   return std::make_unique<ConvertAsyncToLLVMPass>();
877 }
878