1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "convert-async-to-llvm"
27 
28 using namespace mlir;
29 using namespace mlir::async;
30 
// Prefix for functions outlined from `async.execute` op regions. Used as the
// requested symbol name of every outlined function (see outlineExecuteOp).
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
33 
34 //===----------------------------------------------------------------------===//
35 // Async Runtime C API declaration.
36 //===----------------------------------------------------------------------===//
37 
// Symbol names of the Async Runtime C API functions. They are declared in the
// module by addAsyncRuntimeApiDeclarations() with the signatures from AsyncAPI.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
59 
60 namespace {
61 /// Async Runtime API function types.
62 ///
63 /// Because we can't create API function signature for type parametrized
64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
65 /// lowering all async data types become opaque pointers at runtime.
66 struct AsyncAPI {
67   // All async types are lowered to opaque i8* LLVM pointers at runtime.
68   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
69     return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
70   }
71 
72   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
73     auto ref = opaquePointerType(ctx);
74     auto count = IntegerType::get(ctx, 32);
75     return FunctionType::get(ctx, {ref, count}, {});
76   }
77 
78   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
79     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
80   }
81 
82   static FunctionType createValueFunctionType(MLIRContext *ctx) {
83     auto i32 = IntegerType::get(ctx, 32);
84     auto value = opaquePointerType(ctx);
85     return FunctionType::get(ctx, {i32}, {value});
86   }
87 
88   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
89     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
90   }
91 
92   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
93     auto value = opaquePointerType(ctx);
94     auto storage = opaquePointerType(ctx);
95     return FunctionType::get(ctx, {value}, {storage});
96   }
97 
98   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
99     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
100   }
101 
102   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
103     auto value = opaquePointerType(ctx);
104     return FunctionType::get(ctx, {value}, {});
105   }
106 
107   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
108     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
109   }
110 
111   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
112     auto value = opaquePointerType(ctx);
113     return FunctionType::get(ctx, {value}, {});
114   }
115 
116   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
117     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
118   }
119 
120   static FunctionType executeFunctionType(MLIRContext *ctx) {
121     auto hdl = opaquePointerType(ctx);
122     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
123     return FunctionType::get(ctx, {hdl, resume}, {});
124   }
125 
126   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
127     auto i64 = IntegerType::get(ctx, 64);
128     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
129                              {i64});
130   }
131 
132   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
133     auto hdl = opaquePointerType(ctx);
134     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
135     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
136   }
137 
138   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
139     auto value = opaquePointerType(ctx);
140     auto hdl = opaquePointerType(ctx);
141     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
142     return FunctionType::get(ctx, {value, hdl, resume}, {});
143   }
144 
145   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
146     auto hdl = opaquePointerType(ctx);
147     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
148     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
149   }
150 
151   // Auxiliary coroutine resume intrinsic wrapper.
152   static Type resumeFunctionType(MLIRContext *ctx) {
153     auto voidTy = LLVM::LLVMVoidType::get(ctx);
154     auto i8Ptr = opaquePointerType(ctx);
155     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
156   }
157 };
158 } // namespace
159 
160 /// Adds Async Runtime C API declarations to the module.
161 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
162   auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
163                                                          module.getBody());
164 
165   auto addFuncDecl = [&](StringRef name, FunctionType type) {
166     if (module.lookupSymbol(name))
167       return;
168     builder.create<FuncOp>(name, type).setPrivate();
169   };
170 
171   MLIRContext *ctx = module.getContext();
172   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
173   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
174   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
175   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
176   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
177   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
178   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
179   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
180   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
181   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
182   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
183   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
184   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
185   addFuncDecl(kAwaitTokenAndExecute,
186               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
187   addFuncDecl(kAwaitValueAndExecute,
188               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
189   addFuncDecl(kAwaitAllAndExecute,
190               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // LLVM coroutines intrinsics declarations.
195 //===----------------------------------------------------------------------===//
196 
// Names of the LLVM coroutine intrinsics used by this lowering. They are
// declared in the module by addCoroutineIntrinsicsDeclarations().
static constexpr const char *kCoroId = "llvm.coro.id";
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr const char *kCoroBegin = "llvm.coro.begin";
static constexpr const char *kCoroSave = "llvm.coro.save";
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";
static constexpr const char *kCoroResume = "llvm.coro.resume";
205 
206 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
207                             StringRef name, Type ret, ArrayRef<Type> params) {
208   if (module.lookupSymbol(name))
209     return;
210   Type type = LLVM::LLVMFunctionType::get(ret, params);
211   builder.create<LLVM::LLVMFuncOp>(name, type);
212 }
213 
214 /// Adds coroutine intrinsics declarations to the module.
215 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
216   using namespace mlir::LLVM;
217 
218   MLIRContext *ctx = module.getContext();
219   ImplicitLocOpBuilder builder(module.getLoc(),
220                                module.getBody()->getTerminator());
221 
222   auto token = LLVMTokenType::get(ctx);
223   auto voidTy = LLVMVoidType::get(ctx);
224 
225   auto i8 = LLVMIntegerType::get(ctx, 8);
226   auto i1 = LLVMIntegerType::get(ctx, 1);
227   auto i32 = LLVMIntegerType::get(ctx, 32);
228   auto i64 = LLVMIntegerType::get(ctx, 64);
229   auto i8Ptr = LLVMPointerType::get(i8);
230 
231   addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
232   addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
233   addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
234   addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
235   addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
236   addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
237   addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
238   addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Add malloc/free declarations to the module.
243 //===----------------------------------------------------------------------===//
244 
// C library functions used to allocate and release the coroutine frame memory
// (see setupCoroMachinery); declared by addCRuntimeDeclarations().
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
247 
248 /// Adds malloc/free declarations to the module.
249 static void addCRuntimeDeclarations(ModuleOp module) {
250   using namespace mlir::LLVM;
251 
252   MLIRContext *ctx = module.getContext();
253   ImplicitLocOpBuilder builder(module.getLoc(),
254                                module.getBody()->getTerminator());
255 
256   auto voidTy = LLVMVoidType::get(ctx);
257   auto i64 = LLVMIntegerType::get(ctx, 64);
258   auto i8Ptr = LLVMPointerType::get(LLVMIntegerType::get(ctx, 8));
259 
260   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
261   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // Coroutine resume function wrapper.
266 //===----------------------------------------------------------------------===//
267 
// Name of the private LLVM function, created by addResumeFunction(), that
// wraps `llvm.coro.resume`; its address is passed to the async runtime.
static constexpr const char *kResume = "__resume";
269 
270 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
271 /// intrinsics. We need this function to be able to pass it to the async
272 /// runtime execute API.
273 static void addResumeFunction(ModuleOp module) {
274   MLIRContext *ctx = module.getContext();
275 
276   OpBuilder moduleBuilder(module.getBody()->getTerminator());
277   Location loc = module.getLoc();
278 
279   if (module.lookupSymbol(kResume))
280     return;
281 
282   auto voidTy = LLVM::LLVMVoidType::get(ctx);
283   auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
284 
285   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
286       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
287   resumeOp.setPrivate();
288 
289   auto *block = resumeOp.addEntryBlock();
290   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
291 
292   blockBuilder.create<LLVM::CallOp>(TypeRange(),
293                                     blockBuilder.getSymbolRefAttr(kCoroResume),
294                                     resumeOp.getArgument(0));
295 
296   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // async.execute op outlining to the coroutine functions.
301 //===----------------------------------------------------------------------===//
302 
303 /// Function targeted for coroutine transformation has two additional blocks at
304 /// the end: coroutine cleanup and coroutine suspension.
305 ///
306 /// async.await op lowering additionaly creates a resume block for each
307 /// operation to enable non-blocking waiting via coroutine suspension.
308 namespace {
309 struct CoroMachinery {
310   // Async execute region returns a completion token, and an async value for
311   // each yielded value.
312   //
313   //   %token, %result = async.execute -> !async.value<T> {
314   //     %0 = constant ... : T
315   //     async.yield %0 : T
316   //   }
317   Value asyncToken; // token representing completion of the async region
318   llvm::SmallVector<Value, 4> returnValues; // returned async values
319 
320   Value coroHandle;
321   Block *cleanup;
322   Block *suspend;
323 };
324 } // namespace
325 
/// Builds a coroutine template compatible with LLVM coroutines lowering.
///
///  - `entry` block sets up the coroutine.
///  - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
///    returned to the initial caller of a coroutine. Everything before the
///    @llvm.coro.end() will be executed at every suspension point.
///
/// Coroutine structure (only the important bits):
///
///   func @async_execute_fn(<function-arguments>)
///        -> (!async.token, !async.value<T>)
///   {
///     ^entryBlock(<function-arguments>):
///       %token = <async token> : !async.token    // create async runtime token
///       %value = <async value> : !async.value<T> // create async value
///       %id = llvm.call @llvm.coro.id(...)       // create a coroutine id
///       %mem = llvm.call @malloc(...)            // coroutine frame memory
///       %hdl = llvm.call @llvm.coro.begin(...)   // create a coroutine handle
///       br ^cleanup
///
///     ^cleanup:
///       llvm.call @llvm.coro.free(...)  // get coroutine frame memory
///       llvm.call @free(...)            // delete coroutine state
///       br ^suspend
///
///     ^suspend:
///       llvm.call @llvm.coro.end(...)  // marks the end of a coroutine
///       return %token, %value : !async.token, !async.value<T>
///   }
///
/// The actual code for the async.execute operation body region will be inserted
/// before the entry block terminator.
///
/// `func` must have an empty body. Its first result is returned the created
/// `!async.token`; every remaining result must be an `!async.value<T>` (it is
/// cast to ValueType below) and gets a runtime value allocated for it. Returns
/// the created token/values, the coroutine handle and the cleanup/suspend
/// blocks.
static CoroMachinery setupCoroMachinery(FuncOp func) {
  assert(func.getBody().empty() && "Function must have empty body");

  MLIRContext *ctx = func.getContext();

  auto token = LLVM::LLVMTokenType::get(ctx);
  auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
  auto i32 = LLVM::LLVMIntegerType::get(ctx, 32);
  auto i64 = LLVM::LLVMIntegerType::get(ctx, 64);
  auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));

  Block *entryBlock = func.addEntryBlock();
  Location loc = func.getBody().getLoc();

  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock);

  // ------------------------------------------------------------------------ //
  // Allocate async tokens/values that we will return from a ramp function.
  // ------------------------------------------------------------------------ //
  auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));

  // Async value operands and results must be convertible to LLVM types. This is
  // verified before the function outlining.
  LLVMTypeConverter converter(ctx);

  // Returns the size requirements for the async value storage, as a std `i32`
  // value computed with the "GEP one element past a null pointer" idiom.
  // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
  auto sizeOf = [&](ValueType valueType) -> Value {
    auto storedType = converter.convertType(valueType.getValueType());
    auto storagePtrType = LLVM::LLVMPointerType::get(storedType);

    // %Size = getelementptr %T* null, int 1
    // %SizeI = ptrtoint %T* %Size to i32
    auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType);
    auto one = builder.create<LLVM::ConstantOp>(loc, i32,
                                                builder.getI32IntegerAttr(1));
    auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
                                           one.getResult());
    auto size = builder.create<LLVM::PtrToIntOp>(loc, i32, gep);

    // Cast to std type because runtime API defined using std types.
    return builder.create<LLVM::DialectCastOp>(loc, builder.getI32Type(),
                                               size.getResult());
  };

  // We use the `async.value` type as a return type although it does not match
  // the `kCreateValue` function signature, because it will be later lowered to
  // the runtime type (opaque i8* pointer).
  llvm::SmallVector<CallOp, 4> createValues;
  for (auto resultType : func.getCallableResults().drop_front(1))
    createValues.emplace_back(builder.create<CallOp>(
        loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>())));

  auto createdValues = llvm::map_range(
      createValues, [](CallOp call) { return call.getResult(0); });
  llvm::SmallVector<Value, 4> returnValues(createdValues.begin(),
                                           createdValues.end());

  // ------------------------------------------------------------------------ //
  // Initialize coroutine: allocate frame, get coroutine handle.
  // ------------------------------------------------------------------------ //

  // Constants for initializing coroutine frame.
  auto constZero =
      builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0));
  auto constFalse =
      builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false));
  auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr);

  // Get coroutine id: @llvm.coro.id
  auto coroId = builder.create<LLVM::CallOp>(
      token, builder.getSymbolRefAttr(kCoroId),
      ValueRange({constZero, nullPtr, nullPtr, nullPtr}));

  // Get coroutine frame size: @llvm.coro.size.i64
  auto coroSize = builder.create<LLVM::CallOp>(
      i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());

  // Allocate memory for coroutine frame.
  auto coroAlloc =
      builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc),
                                   ValueRange(coroSize.getResult(0)));

  // Begin a coroutine: @llvm.coro.begin
  auto coroHdl = builder.create<LLVM::CallOp>(
      i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
      ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));

  // Both blocks are appended after the entry block, in this order.
  Block *cleanupBlock = func.addBlock();
  Block *suspendBlock = func.addBlock();

  // ------------------------------------------------------------------------ //
  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(cleanupBlock);

  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
  auto coroMem = builder.create<LLVM::CallOp>(
      i8Ptr, builder.getSymbolRefAttr(kCoroFree),
      ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));

  // Free the memory.
  builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree),
                               ValueRange(coroMem.getResult(0)));
  // Branch into the suspend block.
  builder.create<BranchOp>(suspendBlock);

  // ------------------------------------------------------------------------ //
  // Coroutine suspend block: mark the end of a coroutine and return allocated
  // async token.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(suspendBlock);

  // Mark the end of a coroutine: @llvm.coro.end.
  builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
                               ValueRange({coroHdl.getResult(0), constFalse}));

  // Return created `async.token` and `async.values` from the suspend block.
  // This will be the return value of a coroutine ramp function.
  SmallVector<Value, 4> ret{createToken.getResult(0)};
  ret.insert(ret.end(), returnValues.begin(), returnValues.end());
  builder.create<ReturnOp>(loc, ret);

  // Branch from the entry block to the cleanup block to create a valid CFG.
  builder.setInsertionPointToEnd(entryBlock);

  builder.create<BranchOp>(cleanupBlock);

  // `async.await` op lowering will create resume blocks for async
  // continuations, and will conditionally branch to cleanup or suspend blocks.

  CoroMachinery machinery;
  machinery.asyncToken = createToken.getResult(0);
  machinery.returnValues = returnValues;
  machinery.coroHandle = coroHdl.getResult(0);
  machinery.cleanup = cleanupBlock;
  machinery.suspend = suspendBlock;
  return machinery;
}
497 
/// Add a LLVM coroutine suspension point to the end of suspended block, to
/// resume execution in resume block. The caller is responsible for creating the
/// two suspended/resume blocks with the desired ops contained in each block.
/// This function merely provides the required control flow logic.
///
/// `coroState` must be a value returned from the call to @llvm.coro.save(...)
/// intrinsic (saved coroutine state). `op` is only used for its location and
/// context. The builder's insertion point is restored on return.
///
/// Before:
///
///   ^bb0:
///     "opBefore"(...)
///     "op"(...)
///   ^cleanup: ...
///   ^suspend: ...
///   ^resume:
///     "op"(...)
///
/// After:
///
///   ^bb0:
///     "opBefore"(...)
///     %suspend = llvm.call @llvm.coro.suspend(...)
///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
///   ^resume:
///     "op"(...)
///   ^cleanup: ...
///   ^suspend: ...
///
/// The `switch` above is materialized as two conditional branches: `-1` goes
/// to the coroutine suspend block; otherwise a new block, inserted right before
/// `resume`, branches to `resume` on `0` and to the cleanup block on any other
/// value.
static void addSuspensionPoint(CoroMachinery coro, Value coroState,
                               Operation *op, Block *suspended, Block *resume,
                               OpBuilder &builder) {
  Location loc = op->getLoc();
  MLIRContext *ctx = op->getContext();
  auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
  auto i8 = LLVM::LLVMIntegerType::get(ctx, 8);

  // Add a coroutine suspension in place of original `op` in the split block.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToEnd(suspended);

  auto constFalse =
      builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));

  // Suspend a coroutine: @llvm.coro.suspend
  auto coroSuspend = builder.create<LLVM::CallOp>(
      loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
      ValueRange({coroState, constFalse}));

  // After a suspension point decide if we should branch into resume, cleanup
  // or suspend block of the coroutine (see @llvm.coro.suspend return code
  // documentation).
  auto constZero =
      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
  auto constNegOne =
      builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));

  // Intermediate block placed before `resume`; it dispatches between resume
  // and cleanup once the suspend case has been ruled out.
  Block *resumeOrCleanup = builder.createBlock(resume);

  // Suspend the coroutine ...?
  builder.setInsertionPointToEnd(suspended);
  auto isNegOne = builder.create<LLVM::ICmpOp>(
      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
  builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
                                 /*falseDest=*/resumeOrCleanup);

  // ... or resume or cleanup the coroutine?
  builder.setInsertionPointToStart(resumeOrCleanup);
  auto isZero = builder.create<LLVM::ICmpOp>(
      loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
  builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
                                 /*falseDest=*/coro.cleanup);
}
571 
572 /// Outline the body region attached to the `async.execute` op into a standalone
573 /// function.
574 ///
575 /// Note that this is not reversible transformation.
576 static std::pair<FuncOp, CoroMachinery>
577 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
578   ModuleOp module = execute->getParentOfType<ModuleOp>();
579 
580   MLIRContext *ctx = module.getContext();
581   Location loc = execute.getLoc();
582 
583   // Collect all outlined function inputs.
584   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
585                                               execute.dependencies().end());
586   functionInputs.insert(execute.operands().begin(), execute.operands().end());
587   getUsedValuesDefinedAbove(execute.body(), functionInputs);
588 
589   // Collect types for the outlined function inputs and outputs.
590   auto typesRange = llvm::map_range(
591       functionInputs, [](Value value) { return value.getType(); });
592   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
593   auto outputTypes = execute.getResultTypes();
594 
595   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
596   auto funcAttrs = ArrayRef<NamedAttribute>();
597 
598   // TODO: Derive outlined function name from the parent FuncOp (support
599   // multiple nested async.execute operations).
600   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
601   symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
602 
603   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
604 
605   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
606   // blocks, adding llvm.coro instrinsics and setting up control flow.
607   CoroMachinery coro = setupCoroMachinery(func);
608 
609   // Suspend async function at the end of an entry block, and resume it using
610   // Async execute API (execution will be resumed in a thread managed by the
611   // async runtime).
612   Block *entryBlock = &func.getBlocks().front();
613   auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
614 
615   // A pointer to coroutine resume intrinsic wrapper.
616   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
617   auto resumePtr = builder.create<LLVM::AddressOfOp>(
618       LLVM::LLVMPointerType::get(resumeFnTy), kResume);
619 
620   // Save the coroutine state: @llvm.coro.save
621   auto coroSave = builder.create<LLVM::CallOp>(
622       LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
623       ValueRange({coro.coroHandle}));
624 
625   // Call async runtime API to execute a coroutine in the managed thread.
626   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
627   builder.create<CallOp>(TypeRange(), kExecute, executeArgs);
628 
629   // Split the entry block before the terminator.
630   auto *terminatorOp = entryBlock->getTerminator();
631   Block *suspended = terminatorOp->getBlock();
632   Block *resume = suspended->splitBlock(terminatorOp);
633   addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended,
634                      resume, builder);
635 
636   size_t numDependencies = execute.dependencies().size();
637   size_t numOperands = execute.operands().size();
638 
639   // Await on all dependencies before starting to execute the body region.
640   builder.setInsertionPointToStart(resume);
641   for (size_t i = 0; i < numDependencies; ++i)
642     builder.create<AwaitOp>(func.getArgument(i));
643 
644   // Await on all async value operands and unwrap the payload.
645   SmallVector<Value, 4> unwrappedOperands(numOperands);
646   for (size_t i = 0; i < numOperands; ++i) {
647     Value operand = func.getArgument(numDependencies + i);
648     unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
649   }
650 
651   // Map from function inputs defined above the execute op to the function
652   // arguments.
653   BlockAndValueMapping valueMapping;
654   valueMapping.map(functionInputs, func.getArguments());
655   valueMapping.map(execute.body().getArguments(), unwrappedOperands);
656 
657   // Clone all operations from the execute operation body into the outlined
658   // function body.
659   for (Operation &op : execute.body().getOps())
660     builder.clone(op, valueMapping);
661 
662   // Replace the original `async.execute` with a call to outlined function.
663   ImplicitLocOpBuilder callBuilder(loc, execute);
664   auto callOutlinedFunc = callBuilder.create<CallOp>(
665       func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
666   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
667   execute.erase();
668 
669   return {func, coro};
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // Convert Async dialect types to LLVM types.
674 //===----------------------------------------------------------------------===//
675 
676 namespace {
677 
678 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
679 /// their runtime type (opaque pointers) and does not convert any other types.
680 class AsyncRuntimeTypeConverter : public TypeConverter {
681 public:
682   AsyncRuntimeTypeConverter() {
683     addConversion([](Type type) { return type; });
684     addConversion(convertAsyncTypes);
685   }
686 
687   static Optional<Type> convertAsyncTypes(Type type) {
688     if (type.isa<TokenType, GroupType, ValueType>())
689       return AsyncAPI::opaquePointerType(type.getContext());
690     return llvm::None;
691   }
692 };
693 } // namespace
694 
695 //===----------------------------------------------------------------------===//
696 // Convert return operations that return async values from async regions.
697 //===----------------------------------------------------------------------===//
698 
699 namespace {
700 class ReturnOpOpConversion : public ConversionPattern {
701 public:
702   explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx)
703       : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {}
704 
705   LogicalResult
706   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
707                   ConversionPatternRewriter &rewriter) const override {
708     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
709     return success();
710   }
711 };
712 } // namespace
713 
714 //===----------------------------------------------------------------------===//
715 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
716 // to the corresponding API calls).
717 //===----------------------------------------------------------------------===//
718 
719 namespace {
720 
721 template <typename RefCountingOp>
722 class RefCountingOpLowering : public ConversionPattern {
723 public:
724   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
725                                  StringRef apiFunctionName)
726       : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx),
727         apiFunctionName(apiFunctionName) {}
728 
729   LogicalResult
730   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
731                   ConversionPatternRewriter &rewriter) const override {
732     RefCountingOp refCountingOp = cast<RefCountingOp>(op);
733 
734     auto count = rewriter.create<ConstantOp>(
735         op->getLoc(), rewriter.getI32Type(),
736         rewriter.getI32IntegerAttr(refCountingOp.count()));
737 
738     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
739                                         ValueRange({operands[0], count}));
740 
741     return success();
742   }
743 
744 private:
745   StringRef apiFunctionName;
746 };
747 
748 /// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
749 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
750 public:
751   explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
752       : RefCountingOpLowering(converter, ctx, kAddRef) {}
753 };
754 
755 /// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
756 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
757 public:
758   explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
759       : RefCountingOpLowering(converter, ctx, kDropRef) {}
760 };
761 
762 } // namespace
763 
764 //===----------------------------------------------------------------------===//
765 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
766 //===----------------------------------------------------------------------===//
767 
768 namespace {
769 class CreateGroupOpLowering : public ConversionPattern {
770 public:
771   explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
772       : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter,
773                           ctx) {}
774 
775   LogicalResult
776   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
777                   ConversionPatternRewriter &rewriter) const override {
778     auto retTy = GroupType::get(op->getContext());
779     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
780     return success();
781   }
782 };
783 } // namespace
784 
785 //===----------------------------------------------------------------------===//
786 // async.add_to_group op lowering to runtime function call.
787 //===----------------------------------------------------------------------===//
788 
789 namespace {
790 class AddToGroupOpLowering : public ConversionPattern {
791 public:
792   explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
793       : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) {
794   }
795 
796   LogicalResult
797   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
798                   ConversionPatternRewriter &rewriter) const override {
799     // Currently we can only add tokens to the group.
800     auto addToGroup = cast<AddToGroupOp>(op);
801     if (!addToGroup.operand().getType().isa<TokenType>())
802       return failure();
803 
804     auto i64 = IntegerType::get(op->getContext(), 64);
805     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
806     return success();
807   }
808 };
809 } // namespace
810 
811 //===----------------------------------------------------------------------===//
812 // async.await and async.await_all op lowerings to the corresponding async
813 // runtime function calls.
814 //===----------------------------------------------------------------------===//
815 
816 namespace {
817 
/// Common lowering for await-like operations.
///
/// `AwaitType` is the op being lowered (`AwaitOp` or `AwaitAllOp`) and
/// `AwaitableType` is the operand type this instantiation accepts; ops whose
/// operand has a different type are left to other patterns.
///
/// - Outside an outlined coroutine, the op becomes a blocking call to
///   `blockingAwaitFuncName`.
/// - Inside an outlined coroutine (a function present in `outlinedFunctions`),
///   the op becomes a suspension point: the coroutine state is saved, and
///   `coroAwaitFuncName` is called with the awaited operand, the coroutine
///   handle and a pointer to the resume function.
///
/// Subclasses may override `getReplacementValue` to produce a value that
/// replaces the await op; by default the op is erased.
///
/// The pattern keeps a reference to `outlinedFunctions`, so the map must
/// outlive the pattern.
template <typename AwaitType, typename AwaitableType>
class AwaitOpLoweringBase : public ConversionPattern {
protected:
  explicit AwaitOpLoweringBase(
      TypeConverter &converter, MLIRContext *ctx,
      const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
      StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
      : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx),
        outlinedFunctions(outlinedFunctions),
        blockingAwaitFuncName(blockingAwaitFuncName),
        coroAwaitFuncName(coroAwaitFuncName) {}

public:
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // We can only await on the `AwaitableType` (for `await` it can be
    // a `token` or a `value`, for `await_all` it must be a `group`).
    auto await = cast<AwaitType>(op);
    if (!await.operand().getType().template isa<AwaitableType>())
      return failure();

    // Check if await operation is inside the outlined coroutine function.
    auto func = await->template getParentOfType<FuncOp>();
    auto outlined = outlinedFunctions.find(func);
    const bool isInCoroutine = outlined != outlinedFunctions.end();

    Location loc = op->getLoc();

    // Inside regular function we convert await operation to the blocking
    // async API await function call.
    if (!isInCoroutine)
      rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName,
                              ValueRange(operands[0]));

    // Inside the coroutine we convert await operation into coroutine suspension
    // point, and resume execution asynchronously.
    if (isInCoroutine) {
      const CoroMachinery &coro = outlined->getSecond();

      // Builder that inserts before `op` and notifies the rewriter's listener.
      ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
      MLIRContext *ctx = op->getContext();

      // A pointer to coroutine resume intrinsic wrapper.
      auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
      auto resumePtr = builder.create<LLVM::AddressOfOp>(
          LLVM::LLVMPointerType::get(resumeFnTy), kResume);

      // Save the coroutine state: @llvm.coro.save
      auto coroSave = builder.create<LLVM::CallOp>(
          LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
          ValueRange(coro.coroHandle));

      // Call async runtime API to resume a coroutine in the managed thread when
      // the async await argument becomes ready.
      SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle,
                                                   resumePtr.res()};
      builder.create<CallOp>(TypeRange(), coroAwaitFuncName,
                             awaitAndExecuteArgs);

      Block *suspended = op->getBlock();

      // Split the block containing the await operation right before it: the
      // await op and everything after it move into the new `resume` block.
      Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
      // Wire the save token and the two blocks into a coroutine suspension
      // point (helper defined earlier in this file).
      addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
                         builder);

      // Make sure that replacement value will be constructed in resume block.
      rewriter.setInsertionPointToStart(resume);
    }

    // Replace or erase the await operation with the new value.
    if (Value replaceWith = getReplacementValue(op, operands[0], rewriter))
      rewriter.replaceOp(op, replaceWith);
    else
      rewriter.eraseOp(op);

    return success();
  }

  /// Hook for subclasses: returns the value that replaces the await op, built
  /// at the rewriter's current insertion point. A null Value (the default)
  /// means the op has no replacement and is erased.
  virtual Value getReplacementValue(Operation *op, Value operand,
                                    ConversionPatternRewriter &rewriter) const {
    return Value();
  }

private:
  const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
  StringRef blockingAwaitFuncName;
  StringRef coroAwaitFuncName;
};
908 
909 /// Lowering for `async.await` with a token operand.
910 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
911   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
912 
913 public:
914   explicit AwaitTokenOpLowering(
915       TypeConverter &converter, MLIRContext *ctx,
916       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
917       : Base(converter, ctx, outlinedFunctions, kAwaitToken,
918              kAwaitTokenAndExecute) {}
919 };
920 
921 /// Lowering for `async.await` with a value operand.
922 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
923   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
924 
925 public:
926   explicit AwaitValueOpLowering(
927       TypeConverter &converter, MLIRContext *ctx,
928       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
929       : Base(converter, ctx, outlinedFunctions, kAwaitValue,
930              kAwaitValueAndExecute) {}
931 
932   Value
933   getReplacementValue(Operation *op, Value operand,
934                       ConversionPatternRewriter &rewriter) const override {
935     Location loc = op->getLoc();
936     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
937 
938     // Get the underlying value type from the `async.value`.
939     auto await = cast<AwaitOp>(op);
940     auto valueType = await.operand().getType().cast<ValueType>().getValueType();
941 
942     // Get a pointer to an async value storage from the runtime.
943     auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
944                                            TypeRange(i8Ptr), operand);
945 
946     // Cast from i8* to the pointer pointer to LLVM type.
947     auto llvmValueType = getTypeConverter()->convertType(valueType);
948     auto castedStorage = rewriter.create<LLVM::BitcastOp>(
949         loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0));
950 
951     // Load from the async value storage.
952     auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
953 
954     // Cast from LLVM type to the expected value type. This cast will become
955     // no-op after lowering to LLVM.
956     return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded);
957   }
958 };
959 
960 /// Lowering for `async.await_all` operation.
961 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
962   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
963 
964 public:
965   explicit AwaitAllOpLowering(
966       TypeConverter &converter, MLIRContext *ctx,
967       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
968       : Base(converter, ctx, outlinedFunctions, kAwaitGroup,
969              kAwaitAllAndExecute) {}
970 };
971 
972 } // namespace
973 
974 //===----------------------------------------------------------------------===//
975 // async.yield op lowerings to the corresponding async runtime function calls.
976 //===----------------------------------------------------------------------===//
977 
978 class YieldOpLowering : public ConversionPattern {
979 public:
980   explicit YieldOpLowering(
981       TypeConverter &converter, MLIRContext *ctx,
982       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
983       : ConversionPattern(async::YieldOp::getOperationName(), 1, converter,
984                           ctx),
985         outlinedFunctions(outlinedFunctions) {}
986 
987   LogicalResult
988   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
989                   ConversionPatternRewriter &rewriter) const override {
990     // Check if yield operation is inside the outlined coroutine function.
991     auto func = op->template getParentOfType<FuncOp>();
992     auto outlined = outlinedFunctions.find(func);
993     if (outlined == outlinedFunctions.end())
994       return op->emitOpError(
995           "async.yield is not inside the outlined coroutine function");
996 
997     Location loc = op->getLoc();
998     const CoroMachinery &coro = outlined->getSecond();
999 
1000     // Store yielded values into the async values storage and emplace them.
1001     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
1002 
1003     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
1004       // Store `yieldValue` into the `asyncValue` storage.
1005       Value yieldValue = std::get<0>(tuple);
1006       Value asyncValue = std::get<1>(tuple);
1007 
1008       // Get an opaque i8* pointer to an async value storage from the runtime.
1009       auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
1010                                              TypeRange(i8Ptr), asyncValue);
1011 
1012       // Cast storage pointer to the yielded value type.
1013       auto castedStorage = rewriter.create<LLVM::BitcastOp>(
1014           loc, LLVM::LLVMPointerType::get(yieldValue.getType()),
1015           storage.getResult(0));
1016 
1017       // Store the yielded value into the async value storage.
1018       rewriter.create<LLVM::StoreOp>(loc, yieldValue,
1019                                      castedStorage.getResult());
1020 
1021       // Emplace the `async.value` to mark it ready.
1022       rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue);
1023     }
1024 
1025     // Emplace the completion token to mark it ready.
1026     rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
1027 
1028     // Original operation was replaced by the function call(s).
1029     rewriter.eraseOp(op);
1030 
1031     return success();
1032   }
1033 
1034 private:
1035   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
1036 };
1037 
1038 //===----------------------------------------------------------------------===//
1039 
namespace {
/// Module pass that lowers Async dialect operations and types to async
/// runtime API calls, LLVM coroutine intrinsics and LLVM dialect operations.
struct ConvertAsyncToLLVMPass
    : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
  void runOnOperation() override;
};

/// Runs in two phases:
///  1. Outline every `async.execute` region into a coroutine function. The
///     pass fails immediately if any operand/result payload type cannot be
///     converted to an LLVM type. Then add declarations for the resume
///     function, runtime API, coroutine intrinsics and C runtime functions.
///  2. Run a partial dialect conversion that rewrites async types to opaque
///     pointers and async operations to runtime API calls.
void ConvertAsyncToLLVMPass::runOnOperation() {
  ModuleOp module = getOperation();
  SymbolTable symbolTable(module);

  MLIRContext *ctx = &getContext();

  // Outline all `async.execute` body regions into async functions (coroutines).
  // The await/yield patterns below hold a reference to this map, so it must
  // stay alive until the conversion at the end of this function finishes.
  llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;

  // We use conversion to LLVM type to ensure that all `async.value` operands
  // and results can be lowered to LLVM load and store operations.
  LLVMTypeConverter llvmConverter(ctx);
  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);

  // Returns true if the `async.value` payload is convertible to LLVM.
  // NOTE(review): `cast` asserts if `type` is not a ValueType; this relies on
  // ExecuteOp operands()/results() always being `!async.value` typed.
  auto isConvertibleToLlvm = [&](Type type) -> bool {
    auto valueType = type.cast<ValueType>().getValueType();
    return static_cast<bool>(llvmConverter.convertType(valueType));
  };

  // NOTE: outlineExecuteOp erases `execute`; this relies on Operation::walk
  // tolerating erasure of the operation currently being visited.
  WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
    // All operands and results must be convertible to LLVM.
    if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
      execute.emitOpError("operands payload must be convertible to LLVM type");
      return WalkResult::interrupt();
    }
    if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
      execute.emitOpError("results payload must be convertible to LLVM type");
      return WalkResult::interrupt();
    }

    outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));

    return WalkResult::advance();
  });

  // Failed to outline all async execute operations.
  if (outlineResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }

  LLVM_DEBUG({
    llvm::dbgs() << "Outlined " << outlinedFunctions.size()
                 << " async functions\n";
  });

  // Add declarations for all functions required by the coroutines lowering.
  addResumeFunction(module);
  addAsyncRuntimeApiDeclarations(module);
  addCoroutineIntrinsicsDeclarations(module);
  addCRuntimeDeclarations(module);

  // Convert async dialect types and operations to LLVM dialect.
  AsyncRuntimeTypeConverter converter;
  OwningRewritePatternList patterns;

  // Convert async types in function signatures and function calls.
  populateFuncOpTypeConversionPattern(patterns, ctx, converter);
  populateCallOpTypeConversionPattern(patterns, ctx, converter);

  // Convert `std.return` ops whose operand types are not legal for
  // `converter` (e.g. returns of async values).
  patterns.insert<ReturnOpOpConversion>(converter, ctx);

  // Lower async operations to async runtime API calls.
  patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx);
  patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx);

  // Use LLVM type converter to automatically convert between the async value
  // payload type and LLVM type when loading/storing from/to the async
  // value storage which is an opaque i8* pointer using LLVM load/store ops.
  patterns
      .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
          llvmConverter, ctx, outlinedFunctions);
  patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions);

  ConversionTarget target(*ctx);
  target.addLegalOp<ConstantOp>();
  target.addLegalDialect<LLVM::LLVMDialect>();

  // All operations from Async dialect must be lowered to the runtime API calls.
  target.addIllegalDialect<AsyncDialect>();

  // Add dynamic legality constraints to apply conversions defined above:
  // functions, returns and calls are legal only once no async types remain
  // in their signatures/operands according to `converter`.
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addDynamicallyLegalOp<ReturnOp>(
      [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
    return converter.isSignatureLegal(op.getCalleeType());
  });

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
} // namespace
1142 
1143 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1144   return std::make_unique<ConvertAsyncToLLVMPass>();
1145 }
1146