1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "convert-async-to-llvm"
27 
28 using namespace mlir;
29 using namespace mlir::async;
30 
// Prefix for functions outlined from `async.execute` op regions. It is the
// symbol name passed to `FuncOp::create` in `outlineExecuteOp`.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
33 
34 //===----------------------------------------------------------------------===//
35 // Async Runtime C API declaration.
36 //===----------------------------------------------------------------------===//
37 
// Symbol names of the Async runtime C API functions. A declaration for each of
// them is added to the module by `addAsyncRuntimeApiDeclarations`.
//
// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Emplacing tokens and values.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
// Awaiting tokens, values and groups.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Execution of a coroutine handle via a resume function pointer (see
// `AsyncAPI::executeFunctionType`).
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Access to the storage of an async value.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
// Adding a token to a group.
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// "Await and execute" variants: they additionally take a coroutine handle and
// a resume function pointer (see the matching `AsyncAPI` function types).
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
59 
60 namespace {
61 /// Async Runtime API function types.
62 ///
63 /// Because we can't create API function signature for type parametrized
64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
65 /// lowering all async data types become opaque pointers at runtime.
66 struct AsyncAPI {
67   // All async types are lowered to opaque i8* LLVM pointers at runtime.
68   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
69     return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
70   }
71 
72   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
73     auto ref = opaquePointerType(ctx);
74     auto count = IntegerType::get(ctx, 32);
75     return FunctionType::get(ctx, {ref, count}, {});
76   }
77 
78   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
79     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
80   }
81 
82   static FunctionType createValueFunctionType(MLIRContext *ctx) {
83     auto i32 = IntegerType::get(ctx, 32);
84     auto value = opaquePointerType(ctx);
85     return FunctionType::get(ctx, {i32}, {value});
86   }
87 
88   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
89     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
90   }
91 
92   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
93     auto value = opaquePointerType(ctx);
94     auto storage = opaquePointerType(ctx);
95     return FunctionType::get(ctx, {value}, {storage});
96   }
97 
98   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
99     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
100   }
101 
102   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
103     auto value = opaquePointerType(ctx);
104     return FunctionType::get(ctx, {value}, {});
105   }
106 
107   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
108     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
109   }
110 
111   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
112     auto value = opaquePointerType(ctx);
113     return FunctionType::get(ctx, {value}, {});
114   }
115 
116   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
117     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
118   }
119 
120   static FunctionType executeFunctionType(MLIRContext *ctx) {
121     auto hdl = opaquePointerType(ctx);
122     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
123     return FunctionType::get(ctx, {hdl, resume}, {});
124   }
125 
126   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
127     auto i64 = IntegerType::get(ctx, 64);
128     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
129                              {i64});
130   }
131 
132   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
133     auto hdl = opaquePointerType(ctx);
134     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
135     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
136   }
137 
138   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
139     auto value = opaquePointerType(ctx);
140     auto hdl = opaquePointerType(ctx);
141     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
142     return FunctionType::get(ctx, {value, hdl, resume}, {});
143   }
144 
145   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
146     auto hdl = opaquePointerType(ctx);
147     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
148     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
149   }
150 
151   // Auxiliary coroutine resume intrinsic wrapper.
152   static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
153     auto voidTy = LLVM::LLVMVoidType::get(ctx);
154     auto i8Ptr = opaquePointerType(ctx);
155     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
156   }
157 };
158 } // namespace
159 
160 /// Adds Async Runtime C API declarations to the module.
161 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
162   auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
163                                                          module.getBody());
164 
165   auto addFuncDecl = [&](StringRef name, FunctionType type) {
166     if (module.lookupSymbol(name))
167       return;
168     builder.create<FuncOp>(name, type).setPrivate();
169   };
170 
171   MLIRContext *ctx = module.getContext();
172   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
173   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
174   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
175   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
176   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
177   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
178   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
179   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
180   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
181   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
182   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
183   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
184   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
185   addFuncDecl(kAwaitTokenAndExecute,
186               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
187   addFuncDecl(kAwaitValueAndExecute,
188               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
189   addFuncDecl(kAwaitAllAndExecute,
190               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // LLVM coroutines intrinsics declarations.
195 //===----------------------------------------------------------------------===//
196 
// Names of the LLVM coroutine intrinsics used by this lowering. They are
// declared in the module by `addCoroutineIntrinsicsDeclarations`.
static constexpr const char *kCoroId = "llvm.coro.id";
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr const char *kCoroBegin = "llvm.coro.begin";
static constexpr const char *kCoroSave = "llvm.coro.save";
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";
static constexpr const char *kCoroResume = "llvm.coro.resume";
205 
206 /// Adds an LLVM function declaration to a module.
207 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
208                             StringRef name, LLVM::LLVMType ret,
209                             ArrayRef<LLVM::LLVMType> params) {
210   if (module.lookupSymbol(name))
211     return;
212   LLVM::LLVMType type = LLVM::LLVMFunctionType::get(ret, params);
213   builder.create<LLVM::LLVMFuncOp>(name, type);
214 }
215 
216 /// Adds coroutine intrinsics declarations to the module.
217 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
218   using namespace mlir::LLVM;
219 
220   MLIRContext *ctx = module.getContext();
221   ImplicitLocOpBuilder builder(module.getLoc(),
222                                module.getBody()->getTerminator());
223 
224   auto token = LLVMTokenType::get(ctx);
225   auto voidTy = LLVMVoidType::get(ctx);
226 
227   auto i8 = LLVMIntegerType::get(ctx, 8);
228   auto i1 = LLVMIntegerType::get(ctx, 1);
229   auto i32 = LLVMIntegerType::get(ctx, 32);
230   auto i64 = LLVMIntegerType::get(ctx, 64);
231   auto i8Ptr = LLVMPointerType::get(i8);
232 
233   addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
234   addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
235   addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
236   addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
237   addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
238   addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
239   addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
240   addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Add malloc/free declarations to the module.
245 //===----------------------------------------------------------------------===//
246 
247 static constexpr const char *kMalloc = "malloc";
248 static constexpr const char *kFree = "free";
249 
250 /// Adds malloc/free declarations to the module.
251 static void addCRuntimeDeclarations(ModuleOp module) {
252   using namespace mlir::LLVM;
253 
254   MLIRContext *ctx = module.getContext();
255   ImplicitLocOpBuilder builder(module.getLoc(),
256                                module.getBody()->getTerminator());
257 
258   auto voidTy = LLVMVoidType::get(ctx);
259   auto i64 = LLVMIntegerType::get(ctx, 64);
260   auto i8Ptr = LLVMPointerType::get(LLVMIntegerType::get(ctx, 8));
261 
262   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
263   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // Coroutine resume function wrapper.
268 //===----------------------------------------------------------------------===//
269 
270 static constexpr const char *kResume = "__resume";
271 
272 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
273 /// intrinsics. We need this function to be able to pass it to the async
274 /// runtime execute API.
275 static void addResumeFunction(ModuleOp module) {
276   MLIRContext *ctx = module.getContext();
277 
278   OpBuilder moduleBuilder(module.getBody()->getTerminator());
279   Location loc = module.getLoc();
280 
281   if (module.lookupSymbol(kResume))
282     return;
283 
284   auto voidTy = LLVM::LLVMVoidType::get(ctx);
285   auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
286 
287   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
288       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
289   resumeOp.setPrivate();
290 
291   auto *block = resumeOp.addEntryBlock();
292   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
293 
294   blockBuilder.create<LLVM::CallOp>(TypeRange(),
295                                     blockBuilder.getSymbolRefAttr(kCoroResume),
296                                     resumeOp.getArgument(0));
297 
298   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // async.execute op outlining to the coroutine functions.
303 //===----------------------------------------------------------------------===//
304 
305 /// Function targeted for coroutine transformation has two additional blocks at
306 /// the end: coroutine cleanup and coroutine suspension.
307 ///
/// async.await op lowering additionally creates a resume block for each
309 /// operation to enable non-blocking waiting via coroutine suspension.
310 namespace {
311 struct CoroMachinery {
312   // Async execute region returns a completion token, and an async value for
313   // each yielded value.
314   //
315   //   %token, %result = async.execute -> !async.value<T> {
316   //     %0 = constant ... : T
317   //     async.yield %0 : T
318   //   }
319   Value asyncToken; // token representing completion of the async region
320   llvm::SmallVector<Value, 4> returnValues; // returned async values
321 
322   Value coroHandle;
323   Block *cleanup;
324   Block *suspend;
325 };
326 } // namespace
327 
328 /// Builds an coroutine template compatible with LLVM coroutines lowering.
329 ///
330 ///  - `entry` block sets up the coroutine.
331 ///  - `cleanup` block cleans up the coroutine state.
332 ///  - `suspend block after the @llvm.coro.end() defines what value will be
333 ///    returned to the initial caller of a coroutine. Everything before the
334 ///    @llvm.coro.end() will be executed at every suspension point.
335 ///
336 /// Coroutine structure (only the important bits):
337 ///
338 ///   func @async_execute_fn(<function-arguments>)
339 ///        -> (!async.token, !async.value<T>)
340 ///   {
341 ///     ^entryBlock(<function-arguments>):
342 ///       %token = <async token> : !async.token    // create async runtime token
343 ///       %value = <async value> : !async.value<T> // create async value
344 ///       %hdl = llvm.call @llvm.coro.id(...)      // create a coroutine handle
345 ///       br ^cleanup
346 ///
347 ///     ^cleanup:
348 ///       llvm.call @llvm.coro.free(...)  // delete coroutine state
349 ///       br ^suspend
350 ///
351 ///     ^suspend:
352 ///       llvm.call @llvm.coro.end(...)  // marks the end of a coroutine
353 ///       return %token, %value : !async.token, !async.value<T>
354 ///   }
355 ///
356 /// The actual code for the async.execute operation body region will be inserted
357 /// before the entry block terminator.
358 ///
359 ///
360 static CoroMachinery setupCoroMachinery(FuncOp func) {
361   assert(func.getBody().empty() && "Function must have empty body");
362 
363   MLIRContext *ctx = func.getContext();
364 
365   auto token = LLVM::LLVMTokenType::get(ctx);
366   auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
367   auto i32 = LLVM::LLVMIntegerType::get(ctx, 32);
368   auto i64 = LLVM::LLVMIntegerType::get(ctx, 64);
369   auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
370 
371   Block *entryBlock = func.addEntryBlock();
372   Location loc = func.getBody().getLoc();
373 
374   auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock);
375 
376   // ------------------------------------------------------------------------ //
377   // Allocate async tokens/values that we will return from a ramp function.
378   // ------------------------------------------------------------------------ //
379   auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));
380 
381   // Async value operands and results must be convertible to LLVM types. This is
382   // verified before the function outlining.
383   LLVMTypeConverter converter(ctx);
384 
385   // Returns the size requirements for the async value storage.
386   // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
387   auto sizeOf = [&](ValueType valueType) -> Value {
388     auto storedType = converter.convertType(valueType.getValueType());
389     auto storagePtrType =
390         LLVM::LLVMPointerType::get(storedType.cast<LLVM::LLVMType>());
391 
392     // %Size = getelementptr %T* null, int 1
393     // %SizeI = ptrtoint %T* %Size to i32
394     auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType);
395     auto one = builder.create<LLVM::ConstantOp>(loc, i32,
396                                                 builder.getI32IntegerAttr(1));
397     auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
398                                            one.getResult());
399     auto size = builder.create<LLVM::PtrToIntOp>(loc, i32, gep);
400 
401     // Cast to std type because runtime API defined using std types.
402     return builder.create<LLVM::DialectCastOp>(loc, builder.getI32Type(),
403                                                size.getResult());
404   };
405 
406   // We use the `async.value` type as a return type although it does not match
407   // the `kCreateValue` function signature, because it will be later lowered to
408   // the runtime type (opaque i8* pointer).
409   llvm::SmallVector<CallOp, 4> createValues;
410   for (auto resultType : func.getCallableResults().drop_front(1))
411     createValues.emplace_back(builder.create<CallOp>(
412         loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>())));
413 
414   auto createdValues = llvm::map_range(
415       createValues, [](CallOp call) { return call.getResult(0); });
416   llvm::SmallVector<Value, 4> returnValues(createdValues.begin(),
417                                            createdValues.end());
418 
419   // ------------------------------------------------------------------------ //
420   // Initialize coroutine: allocate frame, get coroutine handle.
421   // ------------------------------------------------------------------------ //
422 
423   // Constants for initializing coroutine frame.
424   auto constZero =
425       builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0));
426   auto constFalse =
427       builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false));
428   auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr);
429 
430   // Get coroutine id: @llvm.coro.id
431   auto coroId = builder.create<LLVM::CallOp>(
432       token, builder.getSymbolRefAttr(kCoroId),
433       ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
434 
435   // Get coroutine frame size: @llvm.coro.size.i64
436   auto coroSize = builder.create<LLVM::CallOp>(
437       i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
438 
439   // Allocate memory for coroutine frame.
440   auto coroAlloc =
441       builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc),
442                                    ValueRange(coroSize.getResult(0)));
443 
444   // Begin a coroutine: @llvm.coro.begin
445   auto coroHdl = builder.create<LLVM::CallOp>(
446       i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
447       ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
448 
449   Block *cleanupBlock = func.addBlock();
450   Block *suspendBlock = func.addBlock();
451 
452   // ------------------------------------------------------------------------ //
453   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
454   // ------------------------------------------------------------------------ //
455   builder.setInsertionPointToStart(cleanupBlock);
456 
457   // Get a pointer to the coroutine frame memory: @llvm.coro.free.
458   auto coroMem = builder.create<LLVM::CallOp>(
459       i8Ptr, builder.getSymbolRefAttr(kCoroFree),
460       ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
461 
462   // Free the memory.
463   builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree),
464                                ValueRange(coroMem.getResult(0)));
465   // Branch into the suspend block.
466   builder.create<BranchOp>(suspendBlock);
467 
468   // ------------------------------------------------------------------------ //
469   // Coroutine suspend block: mark the end of a coroutine and return allocated
470   // async token.
471   // ------------------------------------------------------------------------ //
472   builder.setInsertionPointToStart(suspendBlock);
473 
474   // Mark the end of a coroutine: @llvm.coro.end.
475   builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
476                                ValueRange({coroHdl.getResult(0), constFalse}));
477 
478   // Return created `async.token` and `async.values` from the suspend block.
479   // This will be the return value of a coroutine ramp function.
480   SmallVector<Value, 4> ret{createToken.getResult(0)};
481   ret.insert(ret.end(), returnValues.begin(), returnValues.end());
482   builder.create<ReturnOp>(loc, ret);
483 
484   // Branch from the entry block to the cleanup block to create a valid CFG.
485   builder.setInsertionPointToEnd(entryBlock);
486 
487   builder.create<BranchOp>(cleanupBlock);
488 
489   // `async.await` op lowering will create resume blocks for async
490   // continuations, and will conditionally branch to cleanup or suspend blocks.
491 
492   CoroMachinery machinery;
493   machinery.asyncToken = createToken.getResult(0);
494   machinery.returnValues = returnValues;
495   machinery.coroHandle = coroHdl.getResult(0);
496   machinery.cleanup = cleanupBlock;
497   machinery.suspend = suspendBlock;
498   return machinery;
499 }
500 
501 /// Add a LLVM coroutine suspension point to the end of suspended block, to
502 /// resume execution in resume block. The caller is responsible for creating the
503 /// two suspended/resume blocks with the desired ops contained in each block.
504 /// This function merely provides the required control flow logic.
505 ///
506 /// `coroState` must be a value returned from the call to @llvm.coro.save(...)
507 /// intrinsic (saved coroutine state).
508 ///
509 /// Before:
510 ///
511 ///   ^bb0:
512 ///     "opBefore"(...)
513 ///     "op"(...)
514 ///   ^cleanup: ...
515 ///   ^suspend: ...
516 ///   ^resume:
517 ///     "op"(...)
518 ///
519 /// After:
520 ///
521 ///   ^bb0:
522 ///     "opBefore"(...)
523 ///     %suspend = llmv.call @llvm.coro.suspend(...)
524 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
525 ///   ^resume:
526 ///     "op"(...)
527 ///   ^cleanup: ...
528 ///   ^suspend: ...
529 ///
530 static void addSuspensionPoint(CoroMachinery coro, Value coroState,
531                                Operation *op, Block *suspended, Block *resume,
532                                OpBuilder &builder) {
533   Location loc = op->getLoc();
534   MLIRContext *ctx = op->getContext();
535   auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
536   auto i8 = LLVM::LLVMIntegerType::get(ctx, 8);
537 
538   // Add a coroutine suspension in place of original `op` in the split block.
539   OpBuilder::InsertionGuard guard(builder);
540   builder.setInsertionPointToEnd(suspended);
541 
542   auto constFalse =
543       builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
544 
545   // Suspend a coroutine: @llvm.coro.suspend
546   auto coroSuspend = builder.create<LLVM::CallOp>(
547       loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
548       ValueRange({coroState, constFalse}));
549 
550   // After a suspension point decide if we should branch into resume, cleanup
551   // or suspend block of the coroutine (see @llvm.coro.suspend return code
552   // documentation).
553   auto constZero =
554       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
555   auto constNegOne =
556       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
557 
558   Block *resumeOrCleanup = builder.createBlock(resume);
559 
560   // Suspend the coroutine ...?
561   builder.setInsertionPointToEnd(suspended);
562   auto isNegOne = builder.create<LLVM::ICmpOp>(
563       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
564   builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
565                                  /*falseDest=*/resumeOrCleanup);
566 
567   // ... or resume or cleanup the coroutine?
568   builder.setInsertionPointToStart(resumeOrCleanup);
569   auto isZero = builder.create<LLVM::ICmpOp>(
570       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
571   builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
572                                  /*falseDest=*/coro.cleanup);
573 }
574 
575 /// Outline the body region attached to the `async.execute` op into a standalone
576 /// function.
577 ///
578 /// Note that this is not reversible transformation.
579 static std::pair<FuncOp, CoroMachinery>
580 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
581   ModuleOp module = execute->getParentOfType<ModuleOp>();
582 
583   MLIRContext *ctx = module.getContext();
584   Location loc = execute.getLoc();
585 
586   // Collect all outlined function inputs.
587   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
588                                               execute.dependencies().end());
589   functionInputs.insert(execute.operands().begin(), execute.operands().end());
590   getUsedValuesDefinedAbove(execute.body(), functionInputs);
591 
592   // Collect types for the outlined function inputs and outputs.
593   auto typesRange = llvm::map_range(
594       functionInputs, [](Value value) { return value.getType(); });
595   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
596   auto outputTypes = execute.getResultTypes();
597 
598   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
599   auto funcAttrs = ArrayRef<NamedAttribute>();
600 
601   // TODO: Derive outlined function name from the parent FuncOp (support
602   // multiple nested async.execute operations).
603   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
604   symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
605 
606   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
607 
608   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
609   // blocks, adding llvm.coro instrinsics and setting up control flow.
610   CoroMachinery coro = setupCoroMachinery(func);
611 
612   // Suspend async function at the end of an entry block, and resume it using
613   // Async execute API (execution will be resumed in a thread managed by the
614   // async runtime).
615   Block *entryBlock = &func.getBlocks().front();
616   auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
617 
618   // A pointer to coroutine resume intrinsic wrapper.
619   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
620   auto resumePtr = builder.create<LLVM::AddressOfOp>(
621       LLVM::LLVMPointerType::get(resumeFnTy), kResume);
622 
623   // Save the coroutine state: @llvm.coro.save
624   auto coroSave = builder.create<LLVM::CallOp>(
625       LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
626       ValueRange({coro.coroHandle}));
627 
628   // Call async runtime API to execute a coroutine in the managed thread.
629   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
630   builder.create<CallOp>(TypeRange(), kExecute, executeArgs);
631 
632   // Split the entry block before the terminator.
633   auto *terminatorOp = entryBlock->getTerminator();
634   Block *suspended = terminatorOp->getBlock();
635   Block *resume = suspended->splitBlock(terminatorOp);
636   addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended,
637                      resume, builder);
638 
639   size_t numDependencies = execute.dependencies().size();
640   size_t numOperands = execute.operands().size();
641 
642   // Await on all dependencies before starting to execute the body region.
643   builder.setInsertionPointToStart(resume);
644   for (size_t i = 0; i < numDependencies; ++i)
645     builder.create<AwaitOp>(func.getArgument(i));
646 
647   // Await on all async value operands and unwrap the payload.
648   SmallVector<Value, 4> unwrappedOperands(numOperands);
649   for (size_t i = 0; i < numOperands; ++i) {
650     Value operand = func.getArgument(numDependencies + i);
651     unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
652   }
653 
654   // Map from function inputs defined above the execute op to the function
655   // arguments.
656   BlockAndValueMapping valueMapping;
657   valueMapping.map(functionInputs, func.getArguments());
658   valueMapping.map(execute.body().getArguments(), unwrappedOperands);
659 
660   // Clone all operations from the execute operation body into the outlined
661   // function body.
662   for (Operation &op : execute.body().getOps())
663     builder.clone(op, valueMapping);
664 
665   // Replace the original `async.execute` with a call to outlined function.
666   ImplicitLocOpBuilder callBuilder(loc, execute);
667   auto callOutlinedFunc = callBuilder.create<CallOp>(
668       func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
669   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
670   execute.erase();
671 
672   return {func, coro};
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // Convert Async dialect types to LLVM types.
677 //===----------------------------------------------------------------------===//
678 
679 namespace {
680 
681 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
682 /// their runtime type (opaque pointers) and does not convert any other types.
683 class AsyncRuntimeTypeConverter : public TypeConverter {
684 public:
685   AsyncRuntimeTypeConverter() {
686     addConversion([](Type type) { return type; });
687     addConversion(convertAsyncTypes);
688   }
689 
690   static Optional<Type> convertAsyncTypes(Type type) {
691     if (type.isa<TokenType, GroupType, ValueType>())
692       return AsyncAPI::opaquePointerType(type.getContext());
693     return llvm::None;
694   }
695 };
696 } // namespace
697 
698 //===----------------------------------------------------------------------===//
699 // Convert return operations that return async values from async regions.
700 //===----------------------------------------------------------------------===//
701 
702 namespace {
703 class ReturnOpOpConversion : public ConversionPattern {
704 public:
705   explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx)
706       : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {}
707 
708   LogicalResult
709   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
710                   ConversionPatternRewriter &rewriter) const override {
711     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
712     return success();
713   }
714 };
715 } // namespace
716 
717 //===----------------------------------------------------------------------===//
718 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
719 // to the corresponding API calls).
720 //===----------------------------------------------------------------------===//
721 
722 namespace {
723 
724 template <typename RefCountingOp>
725 class RefCountingOpLowering : public ConversionPattern {
726 public:
727   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
728                                  StringRef apiFunctionName)
729       : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx),
730         apiFunctionName(apiFunctionName) {}
731 
732   LogicalResult
733   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
734                   ConversionPatternRewriter &rewriter) const override {
735     RefCountingOp refCountingOp = cast<RefCountingOp>(op);
736 
737     auto count = rewriter.create<ConstantOp>(
738         op->getLoc(), rewriter.getI32Type(),
739         rewriter.getI32IntegerAttr(refCountingOp.count()));
740 
741     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
742                                         ValueRange({operands[0], count}));
743 
744     return success();
745   }
746 
747 private:
748   StringRef apiFunctionName;
749 };
750 
751 /// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
752 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
753 public:
754   explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
755       : RefCountingOpLowering(converter, ctx, kAddRef) {}
756 };
757 
758 /// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
759 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
760 public:
761   explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
762       : RefCountingOpLowering(converter, ctx, kDropRef) {}
763 };
764 
765 } // namespace
766 
767 //===----------------------------------------------------------------------===//
768 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
769 //===----------------------------------------------------------------------===//
770 
771 namespace {
772 class CreateGroupOpLowering : public ConversionPattern {
773 public:
774   explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
775       : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter,
776                           ctx) {}
777 
778   LogicalResult
779   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
780                   ConversionPatternRewriter &rewriter) const override {
781     auto retTy = GroupType::get(op->getContext());
782     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
783     return success();
784   }
785 };
786 } // namespace
787 
788 //===----------------------------------------------------------------------===//
789 // async.add_to_group op lowering to runtime function call.
790 //===----------------------------------------------------------------------===//
791 
792 namespace {
793 class AddToGroupOpLowering : public ConversionPattern {
794 public:
795   explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
796       : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) {
797   }
798 
799   LogicalResult
800   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
801                   ConversionPatternRewriter &rewriter) const override {
802     // Currently we can only add tokens to the group.
803     auto addToGroup = cast<AddToGroupOp>(op);
804     if (!addToGroup.operand().getType().isa<TokenType>())
805       return failure();
806 
807     auto i64 = IntegerType::get(op->getContext(), 64);
808     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
809     return success();
810   }
811 };
812 } // namespace
813 
814 //===----------------------------------------------------------------------===//
815 // async.await and async.await_all op lowerings to the corresponding async
816 // runtime function calls.
817 //===----------------------------------------------------------------------===//
818 
819 namespace {
820 
821 template <typename AwaitType, typename AwaitableType>
822 class AwaitOpLoweringBase : public ConversionPattern {
823 protected:
824   explicit AwaitOpLoweringBase(
825       TypeConverter &converter, MLIRContext *ctx,
826       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
827       StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
828       : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx),
829         outlinedFunctions(outlinedFunctions),
830         blockingAwaitFuncName(blockingAwaitFuncName),
831         coroAwaitFuncName(coroAwaitFuncName) {}
832 
833 public:
834   LogicalResult
835   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
836                   ConversionPatternRewriter &rewriter) const override {
837     // We can only await on one the `AwaitableType` (for `await` it can be
838     // a `token` or a `value`, for `await_all` it must be a `group`).
839     auto await = cast<AwaitType>(op);
840     if (!await.operand().getType().template isa<AwaitableType>())
841       return failure();
842 
843     // Check if await operation is inside the outlined coroutine function.
844     auto func = await->template getParentOfType<FuncOp>();
845     auto outlined = outlinedFunctions.find(func);
846     const bool isInCoroutine = outlined != outlinedFunctions.end();
847 
848     Location loc = op->getLoc();
849 
850     // Inside regular function we convert await operation to the blocking
851     // async API await function call.
852     if (!isInCoroutine)
853       rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName,
854                               ValueRange(operands[0]));
855 
856     // Inside the coroutine we convert await operation into coroutine suspension
857     // point, and resume execution asynchronously.
858     if (isInCoroutine) {
859       const CoroMachinery &coro = outlined->getSecond();
860 
861       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
862       MLIRContext *ctx = op->getContext();
863 
864       // A pointer to coroutine resume intrinsic wrapper.
865       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
866       auto resumePtr = builder.create<LLVM::AddressOfOp>(
867           LLVM::LLVMPointerType::get(resumeFnTy), kResume);
868 
869       // Save the coroutine state: @llvm.coro.save
870       auto coroSave = builder.create<LLVM::CallOp>(
871           LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
872           ValueRange(coro.coroHandle));
873 
874       // Call async runtime API to resume a coroutine in the managed thread when
875       // the async await argument becomes ready.
876       SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle,
877                                                    resumePtr.res()};
878       builder.create<CallOp>(TypeRange(), coroAwaitFuncName,
879                              awaitAndExecuteArgs);
880 
881       Block *suspended = op->getBlock();
882 
883       // Split the entry block before the await operation.
884       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
885       addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
886                          builder);
887 
888       // Make sure that replacement value will be constructed in resume block.
889       rewriter.setInsertionPointToStart(resume);
890     }
891 
892     // Replace or erase the await operation with the new value.
893     if (Value replaceWith = getReplacementValue(op, operands[0], rewriter))
894       rewriter.replaceOp(op, replaceWith);
895     else
896       rewriter.eraseOp(op);
897 
898     return success();
899   }
900 
901   virtual Value getReplacementValue(Operation *op, Value operand,
902                                     ConversionPatternRewriter &rewriter) const {
903     return Value();
904   }
905 
906 private:
907   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
908   StringRef blockingAwaitFuncName;
909   StringRef coroAwaitFuncName;
910 };
911 
912 /// Lowering for `async.await` with a token operand.
913 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
914   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
915 
916 public:
917   explicit AwaitTokenOpLowering(
918       TypeConverter &converter, MLIRContext *ctx,
919       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
920       : Base(converter, ctx, outlinedFunctions, kAwaitToken,
921              kAwaitTokenAndExecute) {}
922 };
923 
924 /// Lowering for `async.await` with a value operand.
925 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
926   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
927 
928 public:
929   explicit AwaitValueOpLowering(
930       TypeConverter &converter, MLIRContext *ctx,
931       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
932       : Base(converter, ctx, outlinedFunctions, kAwaitValue,
933              kAwaitValueAndExecute) {}
934 
935   Value
936   getReplacementValue(Operation *op, Value operand,
937                       ConversionPatternRewriter &rewriter) const override {
938     Location loc = op->getLoc();
939     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
940 
941     // Get the underlying value type from the `async.value`.
942     auto await = cast<AwaitOp>(op);
943     auto valueType = await.operand().getType().cast<ValueType>().getValueType();
944 
945     // Get a pointer to an async value storage from the runtime.
946     auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
947                                            TypeRange(i8Ptr), operand);
948 
949     // Cast from i8* to the pointer pointer to LLVM type.
950     auto llvmValueType = getTypeConverter()->convertType(valueType);
951     auto castedStorage = rewriter.create<LLVM::BitcastOp>(
952         loc, LLVM::LLVMPointerType::get(llvmValueType.cast<LLVM::LLVMType>()),
953         storage.getResult(0));
954 
955     // Load from the async value storage.
956     auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
957 
958     // Cast from LLVM type to the expected value type. This cast will become
959     // no-op after lowering to LLVM.
960     return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded);
961   }
962 };
963 
964 /// Lowering for `async.await_all` operation.
965 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
966   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
967 
968 public:
969   explicit AwaitAllOpLowering(
970       TypeConverter &converter, MLIRContext *ctx,
971       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
972       : Base(converter, ctx, outlinedFunctions, kAwaitGroup,
973              kAwaitAllAndExecute) {}
974 };
975 
976 } // namespace
977 
978 //===----------------------------------------------------------------------===//
979 // async.yield op lowerings to the corresponding async runtime function calls.
980 //===----------------------------------------------------------------------===//
981 
982 class YieldOpLowering : public ConversionPattern {
983 public:
984   explicit YieldOpLowering(
985       TypeConverter &converter, MLIRContext *ctx,
986       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
987       : ConversionPattern(async::YieldOp::getOperationName(), 1, converter,
988                           ctx),
989         outlinedFunctions(outlinedFunctions) {}
990 
991   LogicalResult
992   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
993                   ConversionPatternRewriter &rewriter) const override {
994     // Check if yield operation is inside the outlined coroutine function.
995     auto func = op->template getParentOfType<FuncOp>();
996     auto outlined = outlinedFunctions.find(func);
997     if (outlined == outlinedFunctions.end())
998       return op->emitOpError(
999           "async.yield is not inside the outlined coroutine function");
1000 
1001     Location loc = op->getLoc();
1002     const CoroMachinery &coro = outlined->getSecond();
1003 
1004     // Store yielded values into the async values storage and emplace them.
1005     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
1006 
1007     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
1008       // Store `yieldValue` into the `asyncValue` storage.
1009       Value yieldValue = std::get<0>(tuple);
1010       Value asyncValue = std::get<1>(tuple);
1011 
1012       // Get an opaque i8* pointer to an async value storage from the runtime.
1013       auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
1014                                              TypeRange(i8Ptr), asyncValue);
1015 
1016       // Cast storage pointer to the yielded value type.
1017       auto castedStorage = rewriter.create<LLVM::BitcastOp>(
1018           loc,
1019           LLVM::LLVMPointerType::get(
1020               yieldValue.getType().cast<LLVM::LLVMType>()),
1021           storage.getResult(0));
1022 
1023       // Store the yielded value into the async value storage.
1024       rewriter.create<LLVM::StoreOp>(loc, yieldValue,
1025                                      castedStorage.getResult());
1026 
1027       // Emplace the `async.value` to mark it ready.
1028       rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue);
1029     }
1030 
1031     // Emplace the completion token to mark it ready.
1032     rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
1033 
1034     // Original operation was replaced by the function call(s).
1035     rewriter.eraseOp(op);
1036 
1037     return success();
1038   }
1039 
1040 private:
1041   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
1042 };
1043 
1044 //===----------------------------------------------------------------------===//
1045 
1046 namespace {
1047 struct ConvertAsyncToLLVMPass
1048     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
1049   void runOnOperation() override;
1050 };
1051 
1052 void ConvertAsyncToLLVMPass::runOnOperation() {
1053   ModuleOp module = getOperation();
1054   SymbolTable symbolTable(module);
1055 
1056   MLIRContext *ctx = &getContext();
1057 
1058   // Outline all `async.execute` body regions into async functions (coroutines).
1059   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
1060 
1061   // We use conversion to LLVM type to ensure that all `async.value` operands
1062   // and results can be lowered to LLVM load and store operations.
1063   LLVMTypeConverter llvmConverter(ctx);
1064   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
1065 
1066   // Returns true if the `async.value` payload is convertible to LLVM.
1067   auto isConvertibleToLlvm = [&](Type type) -> bool {
1068     auto valueType = type.cast<ValueType>().getValueType();
1069     return static_cast<bool>(llvmConverter.convertType(valueType));
1070   };
1071 
1072   WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
1073     // All operands and results must be convertible to LLVM.
1074     if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
1075       execute.emitOpError("operands payload must be convertible to LLVM type");
1076       return WalkResult::interrupt();
1077     }
1078     if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
1079       execute.emitOpError("results payload must be convertible to LLVM type");
1080       return WalkResult::interrupt();
1081     }
1082 
1083     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
1084 
1085     return WalkResult::advance();
1086   });
1087 
1088   // Failed to outline all async execute operations.
1089   if (outlineResult.wasInterrupted()) {
1090     signalPassFailure();
1091     return;
1092   }
1093 
1094   LLVM_DEBUG({
1095     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
1096                  << " async functions\n";
1097   });
1098 
1099   // Add declarations for all functions required by the coroutines lowering.
1100   addResumeFunction(module);
1101   addAsyncRuntimeApiDeclarations(module);
1102   addCoroutineIntrinsicsDeclarations(module);
1103   addCRuntimeDeclarations(module);
1104 
1105   // Convert async dialect types and operations to LLVM dialect.
1106   AsyncRuntimeTypeConverter converter;
1107   OwningRewritePatternList patterns;
1108 
1109   // Convert async types in function signatures and function calls.
1110   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
1111   populateCallOpTypeConversionPattern(patterns, ctx, converter);
1112 
1113   // Convert return operations inside async.execute regions.
1114   patterns.insert<ReturnOpOpConversion>(converter, ctx);
1115 
1116   // Lower async operations to async runtime API calls.
1117   patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx);
1118   patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx);
1119 
1120   // Use LLVM type converter to automatically convert between the async value
1121   // payload type and LLVM type when loading/storing from/to the async
1122   // value storage which is an opaque i8* pointer using LLVM load/store ops.
1123   patterns
1124       .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
1125           llvmConverter, ctx, outlinedFunctions);
1126   patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions);
1127 
1128   ConversionTarget target(*ctx);
1129   target.addLegalOp<ConstantOp>();
1130   target.addLegalDialect<LLVM::LLVMDialect>();
1131 
1132   // All operations from Async dialect must be lowered to the runtime API calls.
1133   target.addIllegalDialect<AsyncDialect>();
1134 
1135   // Add dynamic legality constraints to apply conversions defined above.
1136   target.addDynamicallyLegalOp<FuncOp>(
1137       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1138   target.addDynamicallyLegalOp<ReturnOp>(
1139       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
1140   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
1141     return converter.isSignatureLegal(op.getCalleeType());
1142   });
1143 
1144   if (failed(applyPartialConversion(module, target, std::move(patterns))))
1145     signalPassFailure();
1146 }
1147 } // namespace
1148 
1149 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1150   return std::make_unique<ConvertAsyncToLLVMPass>();
1151 }
1152