1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Async/IR/Async.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 #define DEBUG_TYPE "convert-async-to-llvm"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
// Prefix for functions outlined from `async.execute` op regions.
// NOTE(review): every outlined function is created with exactly this name;
// distinct symbols presumably rely on SymbolTable::insert renaming on
// conflict — confirm.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
//===----------------------------------------------------------------------===//

// Symbol names of the Async Runtime C API functions that the lowered code
// calls. The pass declares them in the module (addAsyncRuntimeApiDeclarations).
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
42 
43 namespace {
44 // Async Runtime API function types.
45 struct AsyncAPI {
46   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
47     return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
48   }
49 
50   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
51     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
52   }
53 
54   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
55     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
56   }
57 
58   static FunctionType executeFunctionType(MLIRContext *ctx) {
59     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
60     auto resume = resumeFunctionType(ctx).getPointerTo();
61     return FunctionType::get({hdl, resume}, {}, ctx);
62   }
63 
64   static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
65     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
66     auto resume = resumeFunctionType(ctx).getPointerTo();
67     return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
68   }
69 
70   // Auxiliary coroutine resume intrinsic wrapper.
71   static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
72     auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
73     auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
74     return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
75   }
76 };
77 } // namespace
78 
79 // Adds Async Runtime C API declarations to the module.
80 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
81   auto builder = OpBuilder::atBlockTerminator(module.getBody());
82 
83   MLIRContext *ctx = module.getContext();
84   Location loc = module.getLoc();
85 
86   if (!module.lookupSymbol(kCreateToken))
87     builder.create<FuncOp>(loc, kCreateToken,
88                            AsyncAPI::createTokenFunctionType(ctx));
89 
90   if (!module.lookupSymbol(kEmplaceToken))
91     builder.create<FuncOp>(loc, kEmplaceToken,
92                            AsyncAPI::emplaceTokenFunctionType(ctx));
93 
94   if (!module.lookupSymbol(kAwaitToken))
95     builder.create<FuncOp>(loc, kAwaitToken,
96                            AsyncAPI::awaitTokenFunctionType(ctx));
97 
98   if (!module.lookupSymbol(kExecute))
99     builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx));
100 
101   if (!module.lookupSymbol(kAwaitAndExecute))
102     builder.create<FuncOp>(loc, kAwaitAndExecute,
103                            AsyncAPI::awaitAndExecuteFunctionType(ctx));
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // LLVM coroutines intrinsics declarations.
108 //===----------------------------------------------------------------------===//
109 
110 static constexpr const char *kCoroId = "llvm.coro.id";
111 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
112 static constexpr const char *kCoroBegin = "llvm.coro.begin";
113 static constexpr const char *kCoroSave = "llvm.coro.save";
114 static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
115 static constexpr const char *kCoroEnd = "llvm.coro.end";
116 static constexpr const char *kCoroFree = "llvm.coro.free";
117 static constexpr const char *kCoroResume = "llvm.coro.resume";
118 
119 /// Adds coroutine intrinsics declarations to the module.
120 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
121   using namespace mlir::LLVM;
122 
123   MLIRContext *ctx = module.getContext();
124   Location loc = module.getLoc();
125 
126   OpBuilder builder(module.getBody()->getTerminator());
127 
128   auto token = LLVMTokenType::get(ctx);
129   auto voidTy = LLVMType::getVoidTy(ctx);
130 
131   auto i8 = LLVMType::getInt8Ty(ctx);
132   auto i1 = LLVMType::getInt1Ty(ctx);
133   auto i32 = LLVMType::getInt32Ty(ctx);
134   auto i64 = LLVMType::getInt64Ty(ctx);
135   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
136 
137   if (!module.lookupSymbol(kCoroId))
138     builder.create<LLVMFuncOp>(
139         loc, kCoroId,
140         LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false));
141 
142   if (!module.lookupSymbol(kCoroSizeI64))
143     builder.create<LLVMFuncOp>(loc, kCoroSizeI64,
144                                LLVMType::getFunctionTy(i64, false));
145 
146   if (!module.lookupSymbol(kCoroBegin))
147     builder.create<LLVMFuncOp>(
148         loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
149 
150   if (!module.lookupSymbol(kCoroSave))
151     builder.create<LLVMFuncOp>(loc, kCoroSave,
152                                LLVMType::getFunctionTy(token, i8Ptr, false));
153 
154   if (!module.lookupSymbol(kCoroSuspend))
155     builder.create<LLVMFuncOp>(loc, kCoroSuspend,
156                                LLVMType::getFunctionTy(i8, {token, i1}, false));
157 
158   if (!module.lookupSymbol(kCoroEnd))
159     builder.create<LLVMFuncOp>(loc, kCoroEnd,
160                                LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false));
161 
162   if (!module.lookupSymbol(kCoroFree))
163     builder.create<LLVMFuncOp>(
164         loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
165 
166   if (!module.lookupSymbol(kCoroResume))
167     builder.create<LLVMFuncOp>(loc, kCoroResume,
168                                LLVMType::getFunctionTy(voidTy, i8Ptr, false));
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Add malloc/free declarations to the module.
173 //===----------------------------------------------------------------------===//
174 
175 static constexpr const char *kMalloc = "malloc";
176 static constexpr const char *kFree = "free";
177 
178 /// Adds malloc/free declarations to the module.
179 static void addCRuntimeDeclarations(ModuleOp module) {
180   using namespace mlir::LLVM;
181 
182   MLIRContext *ctx = module.getContext();
183   Location loc = module.getLoc();
184 
185   OpBuilder builder(module.getBody()->getTerminator());
186 
187   auto voidTy = LLVMType::getVoidTy(ctx);
188   auto i64 = LLVMType::getInt64Ty(ctx);
189   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
190 
191   if (!module.lookupSymbol(kMalloc))
192     builder.create<LLVM::LLVMFuncOp>(
193         loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false));
194 
195   if (!module.lookupSymbol(kFree))
196     builder.create<LLVM::LLVMFuncOp>(
197         loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false));
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // Coroutine resume function wrapper.
202 //===----------------------------------------------------------------------===//
203 
204 static constexpr const char *kResume = "__resume";
205 
206 // A function that takes a coroutine handle and calls a `llvm.coro.resume`
207 // intrinsics. We need this function to be able to pass it to the async
208 // runtime execute API.
209 static void addResumeFunction(ModuleOp module) {
210   MLIRContext *ctx = module.getContext();
211 
212   OpBuilder moduleBuilder(module.getBody()->getTerminator());
213   Location loc = module.getLoc();
214 
215   if (module.lookupSymbol(kResume))
216     return;
217 
218   auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
219   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
220 
221   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
222       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
223   SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private);
224 
225   auto *block = resumeOp.addEntryBlock();
226   OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
227 
228   blockBuilder.create<LLVM::CallOp>(loc, Type(),
229                                     blockBuilder.getSymbolRefAttr(kCoroResume),
230                                     resumeOp.getArgument(0));
231 
232   blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // async.execute op outlining to the coroutine functions.
237 //===----------------------------------------------------------------------===//
238 
// Function targeted for coroutine transformation has two additional blocks at
// the end: coroutine cleanup and coroutine suspension.
//
// async.await op lowering additionally creates a resume block for each
// operation to enable non-blocking waiting via coroutine suspension.
namespace {
// Handles to the coroutine scaffolding built by setupCoroMachinery.
struct CoroMachinery {
  // `!async.token` created in the entry block (mlirAsyncRuntimeCreateToken)
  // and returned from the suspend block.
  Value asyncToken;
  // Coroutine handle returned by @llvm.coro.begin.
  Value coroHandle;
  // Block that frees the coroutine frame memory and branches to `suspend`.
  Block *cleanup;
  // Block that calls @llvm.coro.end and returns `asyncToken`.
  Block *suspend;
};
} // namespace
252 
// Builds a coroutine template compatible with LLVM coroutines lowering.
//
//  - `entry` block sets up the coroutine.
//  - `cleanup` block cleans up the coroutine state.
//  - `suspend` block after the @llvm.coro.end() defines what value will be
//    returned to the initial caller of a coroutine. Everything before the
//    @llvm.coro.end() will be executed at every suspension point.
//
// Coroutine structure (only the important bits):
//
//   func @async_execute_fn(<function-arguments>) -> !async.token {
//     ^entryBlock(<function-arguments>):
//       %token = <async token> : !async.token   // create async runtime token
//       %id = llvm.call @llvm.coro.id(...)      // get coroutine id
//       %size = llvm.call @llvm.coro.size.i64() // get coroutine frame size
//       %alloc = llvm.call @malloc(%size)       // allocate coroutine frame
//       %hdl = llvm.call @llvm.coro.begin(...)  // create a coroutine handle
//       br ^cleanup
//
//     ^cleanup:
//       %mem = llvm.call @llvm.coro.free(...)   // get coroutine frame memory
//       llvm.call @free(%mem)                   // free coroutine frame memory
//       br ^suspend
//
//     ^suspend:
//       llvm.call @llvm.coro.end(...)           // marks the end of a coroutine
//       return %token : !async.token
//   }
//
// The actual code for the async.execute operation body region will be inserted
// before the entry block terminator.
//
// `func` must have an empty body; its entry block is created here with the
// function's argument types. Returns the created token, the coroutine handle,
// and the cleanup/suspend blocks.
//
static CoroMachinery setupCoroMachinery(FuncOp func) {
  assert(func.getBody().empty() && "Function must have empty body");

  MLIRContext *ctx = func.getContext();

  auto token = LLVM::LLVMTokenType::get(ctx);
  auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
  auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
  auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
  auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);

  Block *entryBlock = func.addEntryBlock();
  Location loc = func.getBody().getLoc();

  OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);

  // ------------------------------------------------------------------------ //
  // Allocate async tokens/values that we will return from a ramp function.
  // ------------------------------------------------------------------------ //
  auto createToken =
      builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));

  // ------------------------------------------------------------------------ //
  // Initialize coroutine: allocate frame, get coroutine handle.
  // ------------------------------------------------------------------------ //

  // Constants for initializing coroutine frame.
  auto constZero =
      builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
  auto constFalse =
      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
  auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);

  // Get coroutine id: @llvm.coro.id
  auto coroId = builder.create<LLVM::CallOp>(
      loc, token, builder.getSymbolRefAttr(kCoroId),
      ValueRange({constZero, nullPtr, nullPtr, nullPtr}));

  // Get coroutine frame size: @llvm.coro.size.i64
  auto coroSize = builder.create<LLVM::CallOp>(
      loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());

  // Allocate memory for coroutine frame.
  auto coroAlloc = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
      ValueRange(coroSize.getResult(0)));

  // Begin a coroutine: @llvm.coro.begin
  auto coroHdl = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
      ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));

  Block *cleanupBlock = func.addBlock();
  Block *suspendBlock = func.addBlock();

  // ------------------------------------------------------------------------ //
  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(cleanupBlock);

  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
  auto coroMem = builder.create<LLVM::CallOp>(
      loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
      ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));

  // Free the memory.
  builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree),
                               ValueRange(coroMem.getResult(0)));
  // Branch into the suspend block.
  builder.create<BranchOp>(loc, suspendBlock);

  // ------------------------------------------------------------------------ //
  // Coroutine suspend block: mark the end of a coroutine and return allocated
  // async token.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(suspendBlock);

  // Mark the end of a coroutine: @llvm.coro.end.
  builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
                               ValueRange({coroHdl.getResult(0), constFalse}));

  // Return created `async.token` from the suspend block. This will be the
  // return value of a coroutine ramp function.
  builder.create<ReturnOp>(loc, createToken.getResult(0));

  // Branch from the entry block to the cleanup block to create a valid CFG.
  builder.setInsertionPointToEnd(entryBlock);

  builder.create<BranchOp>(loc, cleanupBlock);

  // `async.await` op lowering will create resume blocks for async
  // continuations, and will conditionally branch to cleanup or suspend blocks.

  return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
          suspendBlock};
}
378 
// Adds a suspension point before the `op`, and moves `op` and all operations
// after it into the resume block. Returns a pointer to the resume block.
//
// `coroState` must be a value returned from the call to @llvm.coro.save(...)
// intrinsic (saved coroutine state).
//
// Before:
//
//   ^bb0:
//     "opBefore"(...)
//     "op"(...)
//   ^cleanup: ...
//   ^suspend: ...
//
// After (logically `switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]`,
// emitted as two conditional branches):
//
//   ^bb0:
//     "opBefore"(...)
//     %suspend = llvm.call @llvm.coro.suspend(...)
//     cond_br (%suspend == -1), ^suspend, ^resumeOrCleanup
//   ^resumeOrCleanup:
//     cond_br (%suspend == 0), ^resume, ^cleanup
//   ^resume:
//     "op"(...)
//   ^cleanup: ...
//   ^suspend: ...
//
static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
                                 Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
  auto i8 = LLVM::LLVMType::getInt8Ty(ctx);

  Location loc = op->getLoc();
  Block *splitBlock = op->getBlock();

  // Split the block before `op`, newly added block is the resume block.
  Block *resume = splitBlock->splitBlock(op);

  // Add a coroutine suspension in place of original `op` in the split block.
  OpBuilder builder = OpBuilder::atBlockEnd(splitBlock);

  auto constFalse =
      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));

  // Suspend a coroutine: @llvm.coro.suspend
  auto coroSuspend = builder.create<LLVM::CallOp>(
      loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
      ValueRange({coroState, constFalse}));

  // After a suspension point decide if we should branch into resume, cleanup
  // or suspend block of the coroutine (see @llvm.coro.suspend return code
  // documentation).
  auto constZero =
      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
  auto constNegOne =
      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));

  // Intermediate block that picks between resume and cleanup; it is created
  // in front of `resume`.
  Block *resumeOrCleanup = builder.createBlock(resume);

  // Suspend the coroutine ...?
  builder.setInsertionPointToEnd(splitBlock);
  auto isNegOne = builder.create<LLVM::ICmpOp>(
      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
  builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
                                 /*falseDest=*/resumeOrCleanup);

  // ... or resume or cleanup the coroutine?
  builder.setInsertionPointToStart(resumeOrCleanup);
  auto isZero = builder.create<LLVM::ICmpOp>(
      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
  builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
                                 /*falseDest=*/coro.cleanup);

  return resume;
}
453 
// Outline the body region attached to the `async.execute` op into a standalone
// function.
//
// Values defined above the region and used inside it become the function's
// arguments. The function is set up as a coroutine (setupCoroMachinery). It
// suspends at the end of its entry block, after handing the coroutine handle to
// mlirAsyncRuntimeExecute, and the cloned body runs in the resume block.
// `async.yield` is replaced by a call to mlirAsyncRuntimeEmplaceToken.
//
// The original `execute` op is replaced by a call to the outlined function and
// erased. Returns the outlined function together with its coroutine machinery.
static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
  ModuleOp module = execute.getParentOfType<ModuleOp>();

  MLIRContext *ctx = module.getContext();
  Location loc = execute.getLoc();

  OpBuilder moduleBuilder(module.getBody()->getTerminator());

  // Get values captured by the async region
  llvm::SetVector<mlir::Value> usedAbove;
  getUsedValuesDefinedAbove(execute.body(), usedAbove);

  // Collect types of the captured values.
  auto usedAboveTypes =
      llvm::map_range(usedAbove, [](Value value) { return value.getType(); });
  SmallVector<Type, 4> inputTypes(usedAboveTypes.begin(), usedAboveTypes.end());
  auto outputTypes = execute.getResultTypes();

  auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
  auto funcAttrs = ArrayRef<NamedAttribute>();

  // TODO: Derive outlined function name from the parent FuncOp (support
  // multiple nested async.execute operations).
  FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
  symbolTable.insert(func, moduleBuilder.getInsertionPoint());

  SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);

  // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
  // blocks, adding llvm.coro intrinsics and setting up control flow.
  CoroMachinery coro = setupCoroMachinery(func);

  // Suspend async function at the end of an entry block, and resume it using
  // Async execute API (execution will be resumed in a thread managed by the
  // async runtime).
  Block *entryBlock = &func.getBlocks().front();
  OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);

  // A pointer to coroutine resume intrinsic wrapper.
  auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
  auto resumePtr = builder.create<LLVM::AddressOfOp>(
      loc, resumeFnTy.getPointerTo(), kResume);

  // Save the coroutine state: @llvm.coro.save
  auto coroSave = builder.create<LLVM::CallOp>(
      loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
      ValueRange({coro.coroHandle}));

  // Call async runtime API to execute a coroutine in the managed thread.
  SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
  builder.create<CallOp>(loc, Type(), kExecute, executeArgs);

  // Split the entry block before the terminator.
  Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
                                     entryBlock->getTerminator());

  // Map from values defined above the execute op to the function arguments.
  BlockAndValueMapping valueMapping;
  valueMapping.map(usedAbove, func.getArguments());

  // Clone all operations from the execute operation body into the outlined
  // function body, and replace all `async.yield` operations with a call
  // to async runtime to emplace the result token.
  // NOTE(review): ops from all blocks of the region are cloned into the single
  // resume block, so this presumably assumes a single-block body — confirm.
  builder.setInsertionPointToStart(resume);
  for (Operation &op : execute.body().getOps()) {
    if (isa<async::YieldOp>(op)) {
      builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
      continue;
    }
    builder.clone(op, valueMapping);
  }

  // Replace the original `async.execute` with a call to outlined function.
  OpBuilder callBuilder(execute);
  SmallVector<Value, 4> usedAboveArgs(usedAbove.begin(), usedAbove.end());
  auto callOutlinedFunc = callBuilder.create<CallOp>(
      loc, func.getName(), execute.getResultTypes(), usedAboveArgs);
  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
  execute.erase();

  return {func, coro};
}
539 
540 //===----------------------------------------------------------------------===//
541 // Convert Async dialect types to LLVM types.
542 //===----------------------------------------------------------------------===//
543 
544 namespace {
545 class AsyncRuntimeTypeConverter : public TypeConverter {
546 public:
547   AsyncRuntimeTypeConverter() { addConversion(convertType); }
548 
549   static Type convertType(Type type) {
550     MLIRContext *ctx = type.getContext();
551     // Convert async tokens to opaque pointers.
552     if (type.isa<TokenType>())
553       return LLVM::LLVMType::getInt8PtrTy(ctx);
554     return type;
555   }
556 };
557 } // namespace
558 
559 //===----------------------------------------------------------------------===//
560 // Convert types for all call operations to lowered async types.
561 //===----------------------------------------------------------------------===//
562 
563 namespace {
564 class CallOpOpConversion : public ConversionPattern {
565 public:
566   explicit CallOpOpConversion(MLIRContext *ctx)
567       : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
568 
569   LogicalResult
570   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
571                   ConversionPatternRewriter &rewriter) const override {
572     AsyncRuntimeTypeConverter converter;
573 
574     SmallVector<Type, 5> resultTypes;
575     converter.convertTypes(op->getResultTypes(), resultTypes);
576 
577     CallOp call = cast<CallOp>(op);
578     rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
579                                         call.getOperands());
580 
581     return success();
582   }
583 };
584 } // namespace
585 
586 //===----------------------------------------------------------------------===//
// async.await op lowering to mlirAsyncRuntimeAwaitToken function call, or to a
// coroutine suspension point inside outlined coroutine functions.
588 //===----------------------------------------------------------------------===//
589 
590 namespace {
591 class AwaitOpLowering : public ConversionPattern {
592 public:
593   explicit AwaitOpLowering(
594       MLIRContext *ctx,
595       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
596       : ConversionPattern(AwaitOp::getOperationName(), 1, ctx),
597         outlinedFunctions(outlinedFunctions) {}
598 
599   LogicalResult
600   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
601                   ConversionPatternRewriter &rewriter) const override {
602     // We can only await on the token operand. Async valus are not supported.
603     auto await = cast<AwaitOp>(op);
604     if (!await.operand().getType().isa<TokenType>())
605       return failure();
606 
607     // Check if `async.await` is inside the outlined coroutine function.
608     auto func = await.getParentOfType<FuncOp>();
609     auto outlined = outlinedFunctions.find(func);
610     const bool isInCoroutine = outlined != outlinedFunctions.end();
611 
612     Location loc = op->getLoc();
613 
614     // Inside regular function we convert await operation to the blocking
615     // async API await function call.
616     if (!isInCoroutine)
617       rewriter.create<CallOp>(loc, Type(), kAwaitToken,
618                               ValueRange(op->getOperand(0)));
619 
620     // Inside the coroutine we convert await operation into coroutine suspension
621     // point, and resume execution asynchronously.
622     if (isInCoroutine) {
623       const CoroMachinery &coro = outlined->getSecond();
624 
625       OpBuilder builder(op);
626       MLIRContext *ctx = op->getContext();
627 
628       // A pointer to coroutine resume intrinsic wrapper.
629       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
630       auto resumePtr = builder.create<LLVM::AddressOfOp>(
631           loc, resumeFnTy.getPointerTo(), kResume);
632 
633       // Save the coroutine state: @llvm.coro.save
634       auto coroSave = builder.create<LLVM::CallOp>(
635           loc, LLVM::LLVMTokenType::get(ctx),
636           builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
637 
638       // Call async runtime API to resume a coroutine in the managed thread when
639       // the async await argument becomes ready.
640       SmallVector<Value, 3> awaitAndExecuteArgs = {
641           await.getOperand(), coro.coroHandle, resumePtr.res()};
642       builder.create<CallOp>(loc, Type(), kAwaitAndExecute,
643                              awaitAndExecuteArgs);
644 
645       // Split the entry block before the await operation.
646       addSuspensionPoint(coro, coroSave.getResult(0), op);
647     }
648 
649     // Original operation was replaced by function call or suspension point.
650     rewriter.eraseOp(op);
651 
652     return success();
653   }
654 
655 private:
656   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
657 };
658 } // namespace
659 
660 //===----------------------------------------------------------------------===//
661 
namespace {
// Module pass that outlines `async.execute` regions into coroutine functions.
// It then lowers Async dialect types and ops to calls into the async runtime
// C API and the LLVM coroutine intrinsics.
struct ConvertAsyncToLLVMPass
    : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
  void runOnOperation() override;
};

void ConvertAsyncToLLVMPass::runOnOperation() {
  ModuleOp module = getOperation();
  SymbolTable symbolTable(module);

  // Outline all `async.execute` body regions into async functions (coroutines).
  // Each outlined function maps to its coroutine machinery. AwaitOpLowering
  // keeps a reference to this map, so it must outlive the conversion below.
  llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;

  WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
    // We currently do not support execute operations that take async
    // token dependencies, async value arguments or produce async results.
    if (!execute.dependencies().empty() || !execute.operands().empty() ||
        !execute.results().empty()) {
      execute.emitOpError(
          "Can't outline async.execute op with async dependencies, arguments "
          "or returned async results");
      return WalkResult::interrupt();
    }

    outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));

    return WalkResult::advance();
  });

  // Failed to outline all async execute operations.
  if (outlineResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }

  LLVM_DEBUG({
    llvm::dbgs() << "Outlined " << outlinedFunctions.size()
                 << " async functions\n";
  });

  // Add declarations for all functions required by the coroutines lowering.
  // These are added unconditionally, even when nothing was outlined.
  addResumeFunction(module);
  addAsyncRuntimeApiDeclarations(module);
  addCoroutineIntrinsicsDeclarations(module);
  addCRuntimeDeclarations(module);

  MLIRContext *ctx = &getContext();

  // Convert async dialect types and operations to LLVM dialect.
  AsyncRuntimeTypeConverter converter;
  OwningRewritePatternList patterns;

  populateFuncOpTypeConversionPattern(patterns, ctx, converter);
  patterns.insert<CallOpOpConversion>(ctx);
  patterns.insert<AwaitOpLowering>(ctx, outlinedFunctions);

  // Functions and calls are legal once their signatures / result types no
  // longer mention async types; all async ops must be lowered.
  ConversionTarget target(*ctx);
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addIllegalDialect<AsyncDialect>();
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addDynamicallyLegalOp<CallOp>(
      [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
} // namespace
730 
731 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
732   return std::make_unique<ConvertAsyncToLLVMPass>();
733 }
734