1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "convert-async-to-llvm"
27 
28 using namespace mlir;
29 using namespace mlir::async;
30 
// Base symbol name (prefix) given to every function outlined from an
// `async.execute` op region.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
33 
34 //===----------------------------------------------------------------------===//
35 // Async Runtime C API declaration.
36 //===----------------------------------------------------------------------===//
37 
// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of async tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Emplacing async tokens and values.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";

// Async value storage access and group membership.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Awaiting async tokens, values and groups.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Executing a coroutine handle, optionally once an async object is available.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
59 
60 namespace {
61 /// Async Runtime API function types.
62 ///
63 /// Because we can't create API function signature for type parametrized
64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
65 /// lowering all async data types become opaque pointers at runtime.
66 struct AsyncAPI {
67   // All async types are lowered to opaque i8* LLVM pointers at runtime.
68   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
69     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
70   }
71 
72   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
73     auto ref = opaquePointerType(ctx);
74     auto count = IntegerType::get(ctx, 32);
75     return FunctionType::get(ctx, {ref, count}, {});
76   }
77 
78   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
79     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
80   }
81 
82   static FunctionType createValueFunctionType(MLIRContext *ctx) {
83     auto i32 = IntegerType::get(ctx, 32);
84     auto value = opaquePointerType(ctx);
85     return FunctionType::get(ctx, {i32}, {value});
86   }
87 
88   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
89     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
90   }
91 
92   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
93     auto value = opaquePointerType(ctx);
94     auto storage = opaquePointerType(ctx);
95     return FunctionType::get(ctx, {value}, {storage});
96   }
97 
98   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
99     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
100   }
101 
102   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
103     auto value = opaquePointerType(ctx);
104     return FunctionType::get(ctx, {value}, {});
105   }
106 
107   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
108     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
109   }
110 
111   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
112     auto value = opaquePointerType(ctx);
113     return FunctionType::get(ctx, {value}, {});
114   }
115 
116   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
117     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
118   }
119 
120   static FunctionType executeFunctionType(MLIRContext *ctx) {
121     auto hdl = opaquePointerType(ctx);
122     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
123     return FunctionType::get(ctx, {hdl, resume}, {});
124   }
125 
126   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
127     auto i64 = IntegerType::get(ctx, 64);
128     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
129                              {i64});
130   }
131 
132   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
133     auto hdl = opaquePointerType(ctx);
134     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
135     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
136   }
137 
138   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
139     auto value = opaquePointerType(ctx);
140     auto hdl = opaquePointerType(ctx);
141     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
142     return FunctionType::get(ctx, {value, hdl, resume}, {});
143   }
144 
145   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
146     auto hdl = opaquePointerType(ctx);
147     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
148     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
149   }
150 
151   // Auxiliary coroutine resume intrinsic wrapper.
152   static Type resumeFunctionType(MLIRContext *ctx) {
153     auto voidTy = LLVM::LLVMVoidType::get(ctx);
154     auto i8Ptr = opaquePointerType(ctx);
155     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
156   }
157 };
158 } // namespace
159 
160 /// Adds Async Runtime C API declarations to the module.
161 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
162   auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
163                                                          module.getBody());
164 
165   auto addFuncDecl = [&](StringRef name, FunctionType type) {
166     if (module.lookupSymbol(name))
167       return;
168     builder.create<FuncOp>(name, type).setPrivate();
169   };
170 
171   MLIRContext *ctx = module.getContext();
172   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
173   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
174   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
175   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
176   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
177   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
178   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
179   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
180   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
181   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
182   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
183   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
184   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
185   addFuncDecl(kAwaitTokenAndExecute,
186               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
187   addFuncDecl(kAwaitValueAndExecute,
188               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
189   addFuncDecl(kAwaitAllAndExecute,
190               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // LLVM coroutines intrinsics declarations.
195 //===----------------------------------------------------------------------===//
196 
// Coroutine setup: id, frame size and frame initialization.
static constexpr const char *kCoroId = "llvm.coro.id";
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr const char *kCoroBegin = "llvm.coro.begin";

// Suspension points.
static constexpr const char *kCoroSave = "llvm.coro.save";
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";

// Coroutine end and frame memory release.
static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";

// Resuming a suspended coroutine.
static constexpr const char *kCoroResume = "llvm.coro.resume";
205 
206 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
207                             StringRef name, Type ret, ArrayRef<Type> params) {
208   if (module.lookupSymbol(name))
209     return;
210   Type type = LLVM::LLVMFunctionType::get(ret, params);
211   builder.create<LLVM::LLVMFuncOp>(name, type);
212 }
213 
214 /// Adds coroutine intrinsics declarations to the module.
215 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
216   using namespace mlir::LLVM;
217 
218   MLIRContext *ctx = module.getContext();
219   ImplicitLocOpBuilder builder(module.getLoc(),
220                                module.getBody()->getTerminator());
221 
222   auto token = LLVMTokenType::get(ctx);
223   auto voidTy = LLVMVoidType::get(ctx);
224 
225   auto i8 = IntegerType::get(ctx, 8);
226   auto i1 = IntegerType::get(ctx, 1);
227   auto i32 = IntegerType::get(ctx, 32);
228   auto i64 = IntegerType::get(ctx, 64);
229   auto i8Ptr = LLVMPointerType::get(i8);
230 
231   addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
232   addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
233   addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
234   addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
235   addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
236   addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
237   addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
238   addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Add malloc/free declarations to the module.
243 //===----------------------------------------------------------------------===//
244 
// C runtime allocation functions used for the coroutine frame memory.
static constexpr const char *kMalloc = "malloc"; // i8* malloc(i64)
static constexpr const char *kFree = "free";     // void free(i8*)
247 
248 /// Adds malloc/free declarations to the module.
249 static void addCRuntimeDeclarations(ModuleOp module) {
250   using namespace mlir::LLVM;
251 
252   MLIRContext *ctx = module.getContext();
253   ImplicitLocOpBuilder builder(module.getLoc(),
254                                module.getBody()->getTerminator());
255 
256   auto voidTy = LLVMVoidType::get(ctx);
257   auto i64 = IntegerType::get(ctx, 64);
258   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
259 
260   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
261   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // Coroutine resume function wrapper.
266 //===----------------------------------------------------------------------===//
267 
static constexpr const char *kResume = "__resume"; // resume wrapper symbol
269 
270 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
271 /// intrinsics. We need this function to be able to pass it to the async
272 /// runtime execute API.
273 static void addResumeFunction(ModuleOp module) {
274   MLIRContext *ctx = module.getContext();
275 
276   OpBuilder moduleBuilder(module.getBody()->getTerminator());
277   Location loc = module.getLoc();
278 
279   if (module.lookupSymbol(kResume))
280     return;
281 
282   auto voidTy = LLVM::LLVMVoidType::get(ctx);
283   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
284 
285   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
286       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
287   resumeOp.setPrivate();
288 
289   auto *block = resumeOp.addEntryBlock();
290   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
291 
292   blockBuilder.create<LLVM::CallOp>(TypeRange(),
293                                     blockBuilder.getSymbolRefAttr(kCoroResume),
294                                     resumeOp.getArgument(0));
295 
296   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // async.execute op outlining to the coroutine functions.
301 //===----------------------------------------------------------------------===//
302 
303 /// Function targeted for coroutine transformation has two additional blocks at
304 /// the end: coroutine cleanup and coroutine suspension.
305 ///
306 /// async.await op lowering additionaly creates a resume block for each
307 /// operation to enable non-blocking waiting via coroutine suspension.
308 namespace {
309 struct CoroMachinery {
310   // Async execute region returns a completion token, and an async value for
311   // each yielded value.
312   //
313   //   %token, %result = async.execute -> !async.value<T> {
314   //     %0 = constant ... : T
315   //     async.yield %0 : T
316   //   }
317   Value asyncToken; // token representing completion of the async region
318   llvm::SmallVector<Value, 4> returnValues; // returned async values
319 
320   Value coroHandle;
321   Block *cleanup;
322   Block *suspend;
323 };
324 } // namespace
325 
/// Builds a coroutine template compatible with LLVM coroutines lowering.
///
///  - `entry` block sets up the coroutine.
///  - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
///    returned to the initial caller of a coroutine. Everything before the
///    @llvm.coro.end() will be executed at every suspension point.
///
/// Coroutine structure (only the important bits):
///
///   func @async_execute_fn(<function-arguments>)
///        -> (!async.token, !async.value<T>)
///   {
///     ^entryBlock(<function-arguments>):
///       %token = <async token> : !async.token    // create async runtime token
///       %value = <async value> : !async.value<T> // create async value
///       %id = llvm.call @llvm.coro.id(...)       // create a coroutine id
///       %hdl = llvm.call @llvm.coro.begin(...)   // create a coroutine handle
///       br ^cleanup
///
///     ^cleanup:
///       llvm.call @llvm.coro.free(...)  // delete coroutine state
///       br ^suspend
///
///     ^suspend:
///       llvm.call @llvm.coro.end(...)  // marks the end of a coroutine
///       return %token, %value : !async.token, !async.value<T>
///   }
///
/// The actual code for the async.execute operation body region will be inserted
/// before the entry block terminator.
///
/// `func` must have an empty body. The first result of `func` is the async
/// token; every remaining result must be an `!async.value<T>`. The coroutine
/// frame is allocated with `malloc` (sized by @llvm.coro.size.i64) and released
/// with `free`. The emitted calls reference the runtime API, the coroutine
/// intrinsics and malloc/free only by symbol name; this function does not add
/// their declarations to the module.
static CoroMachinery setupCoroMachinery(FuncOp func) {
  assert(func.getBody().empty() && "Function must have empty body");

  MLIRContext *ctx = func.getContext();

  auto token = LLVM::LLVMTokenType::get(ctx);
  auto i1 = IntegerType::get(ctx, 1);
  auto i32 = IntegerType::get(ctx, 32);
  auto i64 = IntegerType::get(ctx, 64);
  auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));

  Block *entryBlock = func.addEntryBlock();
  Location loc = func.getBody().getLoc();

  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock);

  // ------------------------------------------------------------------------ //
  // Allocate async tokens/values that we will return from a ramp function.
  // ------------------------------------------------------------------------ //
  auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));

  // Async value operands and results must be convertible to LLVM types. This is
  // verified before the function outlining.
  LLVMTypeConverter converter(ctx);

  // Returns the size requirements for the async value storage.
  // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
  // The size is produced as an i32, matching the `kCreateValue` parameter.
  auto sizeOf = [&](ValueType valueType) -> Value {
    auto storedType = converter.convertType(valueType.getValueType());
    auto storagePtrType = LLVM::LLVMPointerType::get(storedType);

    // %Size = getelementptr %T* null, int 1
    // %SizeI = ptrtoint %T* %Size to i32
    auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType);
    auto one = builder.create<LLVM::ConstantOp>(loc, i32,
                                                builder.getI32IntegerAttr(1));
    auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
                                           one.getResult());
    return builder.create<LLVM::PtrToIntOp>(loc, i32, gep);
  };

  // We use the `async.value` type as a return type although it does not match
  // the `kCreateValue` function signature, because it will be later lowered to
  // the runtime type (opaque i8* pointer).
  // The first callable result is the completion token created above; every
  // remaining result is an async value that gets its own runtime storage.
  llvm::SmallVector<CallOp, 4> createValues;
  for (auto resultType : func.getCallableResults().drop_front(1))
    createValues.emplace_back(builder.create<CallOp>(
        loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>())));

  auto createdValues = llvm::map_range(
      createValues, [](CallOp call) { return call.getResult(0); });
  llvm::SmallVector<Value, 4> returnValues(createdValues.begin(),
                                           createdValues.end());

  // ------------------------------------------------------------------------ //
  // Initialize coroutine: allocate frame, get coroutine handle.
  // ------------------------------------------------------------------------ //

  // Constants for initializing coroutine frame.
  // @llvm.coro.id is called with a zero alignment and null promise/coroaddr/
  // fnaddrs arguments.
  auto constZero =
      builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0));
  auto constFalse =
      builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false));
  auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr);

  // Get coroutine id: @llvm.coro.id
  auto coroId = builder.create<LLVM::CallOp>(
      token, builder.getSymbolRefAttr(kCoroId),
      ValueRange({constZero, nullPtr, nullPtr, nullPtr}));

  // Get coroutine frame size: @llvm.coro.size.i64
  auto coroSize = builder.create<LLVM::CallOp>(
      i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());

  // Allocate memory for coroutine frame.
  auto coroAlloc =
      builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc),
                                   ValueRange(coroSize.getResult(0)));

  // Begin a coroutine: @llvm.coro.begin
  auto coroHdl = builder.create<LLVM::CallOp>(
      i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
      ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));

  Block *cleanupBlock = func.addBlock();
  Block *suspendBlock = func.addBlock();

  // ------------------------------------------------------------------------ //
  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(cleanupBlock);

  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
  auto coroMem = builder.create<LLVM::CallOp>(
      i8Ptr, builder.getSymbolRefAttr(kCoroFree),
      ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));

  // Free the memory.
  builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree),
                               ValueRange(coroMem.getResult(0)));
  // Branch into the suspend block.
  builder.create<BranchOp>(suspendBlock);

  // ------------------------------------------------------------------------ //
  // Coroutine suspend block: mark the end of a coroutine and return allocated
  // async token.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(suspendBlock);

  // Mark the end of a coroutine: @llvm.coro.end.
  // The `false` operand is the intrinsic's `unwind` flag (regular end).
  builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
                               ValueRange({coroHdl.getResult(0), constFalse}));

  // Return created `async.token` and `async.values` from the suspend block.
  // This will be the return value of a coroutine ramp function.
  SmallVector<Value, 4> ret{createToken.getResult(0)};
  ret.insert(ret.end(), returnValues.begin(), returnValues.end());
  builder.create<ReturnOp>(loc, ret);

  // Branch from the entry block to the cleanup block to create a valid CFG.
  builder.setInsertionPointToEnd(entryBlock);

  builder.create<BranchOp>(cleanupBlock);

  // `async.await` op lowering will create resume blocks for async
  // continuations, and will conditionally branch to cleanup or suspend blocks.

  CoroMachinery machinery;
  machinery.asyncToken = createToken.getResult(0);
  machinery.returnValues = returnValues;
  machinery.coroHandle = coroHdl.getResult(0);
  machinery.cleanup = cleanupBlock;
  machinery.suspend = suspendBlock;
  return machinery;
}
493 
494 /// Add a LLVM coroutine suspension point to the end of suspended block, to
495 /// resume execution in resume block. The caller is responsible for creating the
496 /// two suspended/resume blocks with the desired ops contained in each block.
497 /// This function merely provides the required control flow logic.
498 ///
499 /// `coroState` must be a value returned from the call to @llvm.coro.save(...)
500 /// intrinsic (saved coroutine state).
501 ///
502 /// Before:
503 ///
504 ///   ^bb0:
505 ///     "opBefore"(...)
506 ///     "op"(...)
507 ///   ^cleanup: ...
508 ///   ^suspend: ...
509 ///   ^resume:
510 ///     "op"(...)
511 ///
512 /// After:
513 ///
514 ///   ^bb0:
515 ///     "opBefore"(...)
516 ///     %suspend = llmv.call @llvm.coro.suspend(...)
517 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
518 ///   ^resume:
519 ///     "op"(...)
520 ///   ^cleanup: ...
521 ///   ^suspend: ...
522 ///
523 static void addSuspensionPoint(CoroMachinery coro, Value coroState,
524                                Operation *op, Block *suspended, Block *resume,
525                                OpBuilder &builder) {
526   Location loc = op->getLoc();
527   MLIRContext *ctx = op->getContext();
528   auto i1 = IntegerType::get(ctx, 1);
529   auto i8 = IntegerType::get(ctx, 8);
530 
531   // Add a coroutine suspension in place of original `op` in the split block.
532   OpBuilder::InsertionGuard guard(builder);
533   builder.setInsertionPointToEnd(suspended);
534 
535   auto constFalse =
536       builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
537 
538   // Suspend a coroutine: @llvm.coro.suspend
539   auto coroSuspend = builder.create<LLVM::CallOp>(
540       loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
541       ValueRange({coroState, constFalse}));
542 
543   // After a suspension point decide if we should branch into resume, cleanup
544   // or suspend block of the coroutine (see @llvm.coro.suspend return code
545   // documentation).
546   auto constZero =
547       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
548   auto constNegOne =
549       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
550 
551   Block *resumeOrCleanup = builder.createBlock(resume);
552 
553   // Suspend the coroutine ...?
554   builder.setInsertionPointToEnd(suspended);
555   auto isNegOne = builder.create<LLVM::ICmpOp>(
556       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
557   builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
558                                  /*falseDest=*/resumeOrCleanup);
559 
560   // ... or resume or cleanup the coroutine?
561   builder.setInsertionPointToStart(resumeOrCleanup);
562   auto isZero = builder.create<LLVM::ICmpOp>(
563       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
564   builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
565                                  /*falseDest=*/coro.cleanup);
566 }
567 
568 /// Outline the body region attached to the `async.execute` op into a standalone
569 /// function.
570 ///
571 /// Note that this is not reversible transformation.
572 static std::pair<FuncOp, CoroMachinery>
573 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
574   ModuleOp module = execute->getParentOfType<ModuleOp>();
575 
576   MLIRContext *ctx = module.getContext();
577   Location loc = execute.getLoc();
578 
579   // Collect all outlined function inputs.
580   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
581                                               execute.dependencies().end());
582   functionInputs.insert(execute.operands().begin(), execute.operands().end());
583   getUsedValuesDefinedAbove(execute.body(), functionInputs);
584 
585   // Collect types for the outlined function inputs and outputs.
586   auto typesRange = llvm::map_range(
587       functionInputs, [](Value value) { return value.getType(); });
588   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
589   auto outputTypes = execute.getResultTypes();
590 
591   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
592   auto funcAttrs = ArrayRef<NamedAttribute>();
593 
594   // TODO: Derive outlined function name from the parent FuncOp (support
595   // multiple nested async.execute operations).
596   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
597   symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
598 
599   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
600 
601   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
602   // blocks, adding llvm.coro instrinsics and setting up control flow.
603   CoroMachinery coro = setupCoroMachinery(func);
604 
605   // Suspend async function at the end of an entry block, and resume it using
606   // Async execute API (execution will be resumed in a thread managed by the
607   // async runtime).
608   Block *entryBlock = &func.getBlocks().front();
609   auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
610 
611   // A pointer to coroutine resume intrinsic wrapper.
612   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
613   auto resumePtr = builder.create<LLVM::AddressOfOp>(
614       LLVM::LLVMPointerType::get(resumeFnTy), kResume);
615 
616   // Save the coroutine state: @llvm.coro.save
617   auto coroSave = builder.create<LLVM::CallOp>(
618       LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
619       ValueRange({coro.coroHandle}));
620 
621   // Call async runtime API to execute a coroutine in the managed thread.
622   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
623   builder.create<CallOp>(TypeRange(), kExecute, executeArgs);
624 
625   // Split the entry block before the terminator.
626   auto *terminatorOp = entryBlock->getTerminator();
627   Block *suspended = terminatorOp->getBlock();
628   Block *resume = suspended->splitBlock(terminatorOp);
629   addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended,
630                      resume, builder);
631 
632   size_t numDependencies = execute.dependencies().size();
633   size_t numOperands = execute.operands().size();
634 
635   // Await on all dependencies before starting to execute the body region.
636   builder.setInsertionPointToStart(resume);
637   for (size_t i = 0; i < numDependencies; ++i)
638     builder.create<AwaitOp>(func.getArgument(i));
639 
640   // Await on all async value operands and unwrap the payload.
641   SmallVector<Value, 4> unwrappedOperands(numOperands);
642   for (size_t i = 0; i < numOperands; ++i) {
643     Value operand = func.getArgument(numDependencies + i);
644     unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
645   }
646 
647   // Map from function inputs defined above the execute op to the function
648   // arguments.
649   BlockAndValueMapping valueMapping;
650   valueMapping.map(functionInputs, func.getArguments());
651   valueMapping.map(execute.body().getArguments(), unwrappedOperands);
652 
653   // Clone all operations from the execute operation body into the outlined
654   // function body.
655   for (Operation &op : execute.body().getOps())
656     builder.clone(op, valueMapping);
657 
658   // Replace the original `async.execute` with a call to outlined function.
659   ImplicitLocOpBuilder callBuilder(loc, execute);
660   auto callOutlinedFunc = callBuilder.create<CallOp>(
661       func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
662   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
663   execute.erase();
664 
665   return {func, coro};
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // Convert Async dialect types to LLVM types.
670 //===----------------------------------------------------------------------===//
671 
672 namespace {
673 
674 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
675 /// their runtime type (opaque pointers) and does not convert any other types.
676 class AsyncRuntimeTypeConverter : public TypeConverter {
677 public:
678   AsyncRuntimeTypeConverter() {
679     addConversion([](Type type) { return type; });
680     addConversion(convertAsyncTypes);
681   }
682 
683   static Optional<Type> convertAsyncTypes(Type type) {
684     if (type.isa<TokenType, GroupType, ValueType>())
685       return AsyncAPI::opaquePointerType(type.getContext());
686     return llvm::None;
687   }
688 };
689 } // namespace
690 
691 //===----------------------------------------------------------------------===//
692 // Convert return operations that return async values from async regions.
693 //===----------------------------------------------------------------------===//
694 
695 namespace {
696 class ReturnOpOpConversion : public ConversionPattern {
697 public:
698   explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx)
699       : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {}
700 
701   LogicalResult
702   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
703                   ConversionPatternRewriter &rewriter) const override {
704     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
705     return success();
706   }
707 };
708 } // namespace
709 
710 //===----------------------------------------------------------------------===//
711 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
712 // to the corresponding API calls).
713 //===----------------------------------------------------------------------===//
714 
715 namespace {
716 
717 template <typename RefCountingOp>
718 class RefCountingOpLowering : public ConversionPattern {
719 public:
720   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
721                                  StringRef apiFunctionName)
722       : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx),
723         apiFunctionName(apiFunctionName) {}
724 
725   LogicalResult
726   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
727                   ConversionPatternRewriter &rewriter) const override {
728     RefCountingOp refCountingOp = cast<RefCountingOp>(op);
729 
730     auto count = rewriter.create<ConstantOp>(
731         op->getLoc(), rewriter.getI32Type(),
732         rewriter.getI32IntegerAttr(refCountingOp.count()));
733 
734     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
735                                         ValueRange({operands[0], count}));
736 
737     return success();
738   }
739 
740 private:
741   StringRef apiFunctionName;
742 };
743 
744 /// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
745 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
746 public:
747   explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
748       : RefCountingOpLowering(converter, ctx, kAddRef) {}
749 };
750 
751 /// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
752 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
753 public:
754   explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
755       : RefCountingOpLowering(converter, ctx, kDropRef) {}
756 };
757 
758 } // namespace
759 
760 //===----------------------------------------------------------------------===//
761 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
762 //===----------------------------------------------------------------------===//
763 
764 namespace {
765 class CreateGroupOpLowering : public ConversionPattern {
766 public:
767   explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
768       : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter,
769                           ctx) {}
770 
771   LogicalResult
772   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
773                   ConversionPatternRewriter &rewriter) const override {
774     auto retTy = GroupType::get(op->getContext());
775     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
776     return success();
777   }
778 };
779 } // namespace
780 
781 //===----------------------------------------------------------------------===//
782 // async.add_to_group op lowering to runtime function call.
783 //===----------------------------------------------------------------------===//
784 
785 namespace {
786 class AddToGroupOpLowering : public ConversionPattern {
787 public:
788   explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
789       : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) {
790   }
791 
792   LogicalResult
793   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
794                   ConversionPatternRewriter &rewriter) const override {
795     // Currently we can only add tokens to the group.
796     auto addToGroup = cast<AddToGroupOp>(op);
797     if (!addToGroup.operand().getType().isa<TokenType>())
798       return failure();
799 
800     auto i64 = IntegerType::get(op->getContext(), 64);
801     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
802     return success();
803   }
804 };
805 } // namespace
806 
807 //===----------------------------------------------------------------------===//
808 // async.await and async.await_all op lowerings to the corresponding async
809 // runtime function calls.
810 //===----------------------------------------------------------------------===//
811 
namespace {

/// Shared lowering for `async.await` and `async.await_all`.
///
/// Outside an outlined coroutine function the op becomes a blocking call to
/// `blockingAwaitFuncName`. Inside an outlined coroutine (one recorded in
/// `outlinedFunctions`) the op becomes a coroutine suspension point: the
/// runtime is asked (via `coroAwaitFuncName`) to resume the coroutine once the
/// awaited operand is ready, and the current block is split at the await.
///
/// Subclasses may override `getReplacementValue` to produce a value that
/// replaces the await op's result; by default the op is simply erased.
template <typename AwaitType, typename AwaitableType>
class AwaitOpLoweringBase : public ConversionPattern {
protected:
  explicit AwaitOpLoweringBase(
      TypeConverter &converter, MLIRContext *ctx,
      const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
      StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
      : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx),
        outlinedFunctions(outlinedFunctions),
        blockingAwaitFuncName(blockingAwaitFuncName),
        coroAwaitFuncName(coroAwaitFuncName) {}

public:
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // We can only await on the `AwaitableType` (for `await` it can be a
    // `token` or a `value`, for `await_all` it must be a `group`); operands of
    // any other type are left for a sibling pattern instantiation.
    auto await = cast<AwaitType>(op);
    if (!await.operand().getType().template isa<AwaitableType>())
      return failure();

    // Check if await operation is inside the outlined coroutine function.
    auto func = await->template getParentOfType<FuncOp>();
    auto outlined = outlinedFunctions.find(func);
    const bool isInCoroutine = outlined != outlinedFunctions.end();

    Location loc = op->getLoc();

    // Inside regular function we convert await operation to the blocking
    // async API await function call.
    if (!isInCoroutine)
      rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName,
                              ValueRange(operands[0]));

    // Inside the coroutine we convert await operation into coroutine suspension
    // point, and resume execution asynchronously.
    if (isInCoroutine) {
      const CoroMachinery &coro = outlined->getSecond();

      // Builder inserting before `op`, reporting to the rewriter's listener.
      ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
      MLIRContext *ctx = op->getContext();

      // A pointer to coroutine resume intrinsic wrapper.
      auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
      auto resumePtr = builder.create<LLVM::AddressOfOp>(
          LLVM::LLVMPointerType::get(resumeFnTy), kResume);

      // Save the coroutine state: @llvm.coro.save
      auto coroSave = builder.create<LLVM::CallOp>(
          LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
          ValueRange(coro.coroHandle));

      // Call async runtime API to resume a coroutine in the managed thread when
      // the async await argument becomes ready.
      SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle,
                                                   resumePtr.res()};
      builder.create<CallOp>(TypeRange(), coroAwaitFuncName,
                             awaitAndExecuteArgs);

      Block *suspended = op->getBlock();

      // Split the block containing the await right before the await op: the
      // await (and everything after it) moves into the new `resume` block.
      Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
      addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
                         builder);

      // Make sure that replacement value will be constructed in resume block.
      rewriter.setInsertionPointToStart(resume);
    }

    // Replace or erase the await operation with the new value.
    if (Value replaceWith = getReplacementValue(op, operands[0], rewriter))
      rewriter.replaceOp(op, replaceWith);
    else
      rewriter.eraseOp(op);

    return success();
  }

  /// Hook for subclasses: returns the value replacing the await op's result,
  /// or a null Value if the op has nothing to replace (it is then erased).
  /// `operand` is the converted awaited operand.
  virtual Value getReplacementValue(Operation *op, Value operand,
                                    ConversionPatternRewriter &rewriter) const {
    return Value();
  }

private:
  // Non-owning reference to the outlined-function map; the map must outlive
  // this pattern.
  const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
  // Runtime function used outside coroutines (blocks the caller).
  StringRef blockingAwaitFuncName;
  // Runtime function used inside coroutines (resumes the coroutine later).
  StringRef coroAwaitFuncName;
};

/// Lowering for `async.await` with a token operand.
class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
  using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;

public:
  explicit AwaitTokenOpLowering(
      TypeConverter &converter, MLIRContext *ctx,
      const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
      : Base(converter, ctx, outlinedFunctions, kAwaitToken,
             kAwaitTokenAndExecute) {}
};

/// Lowering for `async.await` with a value operand. After the await, the
/// payload is loaded from the runtime-provided value storage and replaces the
/// op's result.
class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
  using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;

public:
  explicit AwaitValueOpLowering(
      TypeConverter &converter, MLIRContext *ctx,
      const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
      : Base(converter, ctx, outlinedFunctions, kAwaitValue,
             kAwaitValueAndExecute) {}

  Value
  getReplacementValue(Operation *op, Value operand,
                      ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());

    // Get the underlying value type from the `async.value`.
    auto await = cast<AwaitOp>(op);
    auto valueType = await.operand().getType().cast<ValueType>().getValueType();

    // Get a pointer to an async value storage from the runtime.
    auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
                                           TypeRange(i8Ptr), operand);

    // Cast from i8* to the pointer pointer to LLVM type.
    // NOTE(review): convertType returns a null type if the payload is not
    // LLVM-convertible; the pass only pre-checks `async.execute` operands and
    // results — confirm every awaited value is covered by that check.
    auto llvmValueType = getTypeConverter()->convertType(valueType);
    auto castedStorage = rewriter.create<LLVM::BitcastOp>(
        loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0));

    // Load from the async value storage.
    return rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
  }
};

/// Lowering for `async.await_all` operation.
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
  using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;

public:
  explicit AwaitAllOpLowering(
      TypeConverter &converter, MLIRContext *ctx,
      const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
      : Base(converter, ctx, outlinedFunctions, kAwaitGroup,
             kAwaitAllAndExecute) {}
};

} // namespace
965 
966 //===----------------------------------------------------------------------===//
967 // async.yield op lowerings to the corresponding async runtime function calls.
968 //===----------------------------------------------------------------------===//
969 
970 class YieldOpLowering : public ConversionPattern {
971 public:
972   explicit YieldOpLowering(
973       TypeConverter &converter, MLIRContext *ctx,
974       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
975       : ConversionPattern(async::YieldOp::getOperationName(), 1, converter,
976                           ctx),
977         outlinedFunctions(outlinedFunctions) {}
978 
979   LogicalResult
980   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
981                   ConversionPatternRewriter &rewriter) const override {
982     // Check if yield operation is inside the outlined coroutine function.
983     auto func = op->template getParentOfType<FuncOp>();
984     auto outlined = outlinedFunctions.find(func);
985     if (outlined == outlinedFunctions.end())
986       return op->emitOpError(
987           "async.yield is not inside the outlined coroutine function");
988 
989     Location loc = op->getLoc();
990     const CoroMachinery &coro = outlined->getSecond();
991 
992     // Store yielded values into the async values storage and emplace them.
993     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
994 
995     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
996       // Store `yieldValue` into the `asyncValue` storage.
997       Value yieldValue = std::get<0>(tuple);
998       Value asyncValue = std::get<1>(tuple);
999 
1000       // Get an opaque i8* pointer to an async value storage from the runtime.
1001       auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
1002                                              TypeRange(i8Ptr), asyncValue);
1003 
1004       // Cast storage pointer to the yielded value type.
1005       auto castedStorage = rewriter.create<LLVM::BitcastOp>(
1006           loc, LLVM::LLVMPointerType::get(yieldValue.getType()),
1007           storage.getResult(0));
1008 
1009       // Store the yielded value into the async value storage.
1010       rewriter.create<LLVM::StoreOp>(loc, yieldValue,
1011                                      castedStorage.getResult());
1012 
1013       // Emplace the `async.value` to mark it ready.
1014       rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue);
1015     }
1016 
1017     // Emplace the completion token to mark it ready.
1018     rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
1019 
1020     // Original operation was replaced by the function call(s).
1021     rewriter.eraseOp(op);
1022 
1023     return success();
1024   }
1025 
1026 private:
1027   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
1028 };
1029 
1030 //===----------------------------------------------------------------------===//
1031 
namespace {
/// Module pass that outlines `async.execute` regions into coroutine functions
/// and then lowers Async dialect operations and types to async runtime API
/// calls and LLVM dialect operations.
struct ConvertAsyncToLLVMPass
    : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
  void runOnOperation() override;
};

/// Runs the lowering in three phases:
///   1. outline every `async.execute` (after checking its payload types are
///      LLVM-convertible),
///   2. declare the resume function, async runtime API, coroutine intrinsics
///      and C runtime functions in the module,
///   3. apply a partial conversion that rewrites async types and ops.
void ConvertAsyncToLLVMPass::runOnOperation() {
  ModuleOp module = getOperation();
  SymbolTable symbolTable(module);

  MLIRContext *ctx = &getContext();

  // Outline all `async.execute` body regions into async functions (coroutines).
  // Maps each outlined function to its coroutine machinery; the await/yield
  // patterns below hold a reference to this map.
  llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;

  // We use conversion to LLVM type to ensure that all `async.value` operands
  // and results can be lowered to LLVM load and store operations.
  LLVMTypeConverter llvmConverter(ctx);
  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);

  // Returns true if the `async.value` payload is convertible to LLVM.
  // Asserts (via `cast`) that `type` is an `async.value` type.
  auto isConvertibleToLlvm = [&](Type type) -> bool {
    auto valueType = type.cast<ValueType>().getValueType();
    return static_cast<bool>(llvmConverter.convertType(valueType));
  };

  // NOTE(review): outlineExecuteOp erases the visited `async.execute`; this
  // relies on the walk tolerating erasure of the current op — confirm.
  WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
    // All operands and results must be convertible to LLVM.
    if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
      execute.emitOpError("operands payload must be convertible to LLVM type");
      return WalkResult::interrupt();
    }
    if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
      execute.emitOpError("results payload must be convertible to LLVM type");
      return WalkResult::interrupt();
    }

    outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));

    return WalkResult::advance();
  });

  // Failed to outline all async execute operations.
  if (outlineResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }

  LLVM_DEBUG({
    llvm::dbgs() << "Outlined " << outlinedFunctions.size()
                 << " async functions\n";
  });

  // Add declarations for all functions required by the coroutines lowering.
  addResumeFunction(module);
  addAsyncRuntimeApiDeclarations(module);
  addCoroutineIntrinsicsDeclarations(module);
  addCRuntimeDeclarations(module);

  // Convert async dialect types and operations to LLVM dialect.
  // `converter` only maps async types to opaque pointers; all other types are
  // left unchanged by it.
  AsyncRuntimeTypeConverter converter;
  OwningRewritePatternList patterns;

  // Convert async types in function signatures and function calls.
  populateFuncOpTypeConversionPattern(patterns, ctx, converter);
  populateCallOpTypeConversionPattern(patterns, ctx, converter);

  // Convert return operations inside async.execute regions.
  patterns.insert<ReturnOpOpConversion>(converter, ctx);

  // Lower async operations to async runtime API calls.
  patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx);
  patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx);

  // Use LLVM type converter to automatically convert between the async value
  // payload type and LLVM type when loading/storing from/to the async
  // value storage which is an opaque i8* pointer using LLVM load/store ops.
  patterns
      .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
          llvmConverter, ctx, outlinedFunctions);
  patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions);

  // Constants (e.g. ref-count deltas) and LLVM dialect ops are always legal.
  ConversionTarget target(*ctx);
  target.addLegalOp<ConstantOp>();
  target.addLegalDialect<LLVM::LLVMDialect>();

  // All operations from Async dialect must be lowered to the runtime API calls.
  target.addIllegalDialect<AsyncDialect>();

  // Add dynamic legality constraints to apply conversions defined above.
  // Functions, returns and calls are legal once no async types remain in
  // their signatures / operand types.
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addDynamicallyLegalOp<ReturnOp>(
      [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
    return converter.isSignatureLegal(op.getCalleeType());
  });

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
} // namespace
1134 
1135 namespace {
1136 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1137 public:
1138   using OpConversionPattern::OpConversionPattern;
1139   LogicalResult
1140   matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
1141                   ConversionPatternRewriter &rewriter) const override {
1142     ExecuteOp newOp =
1143         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1144     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1145                                 newOp.getRegion().end());
1146 
1147     // Set operands and update block argument and result types.
1148     newOp->setOperands(operands);
1149     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1150       return failure();
1151     for (auto result : newOp.getResults())
1152       result.setType(typeConverter->convertType(result.getType()));
1153 
1154     rewriter.replaceOp(op, newOp.getResults());
1155     return success();
1156   }
1157 };
1158 
1159 // Dummy pattern to trigger the appropriate type conversion / materialization.
1160 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1161 public:
1162   using OpConversionPattern::OpConversionPattern;
1163   LogicalResult
1164   matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
1165                   ConversionPatternRewriter &rewriter) const override {
1166     rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
1167     return success();
1168   }
1169 };
1170 
1171 // Dummy pattern to trigger the appropriate type conversion / materialization.
1172 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1173 public:
1174   using OpConversionPattern::OpConversionPattern;
1175   LogicalResult
1176   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
1177                   ConversionPatternRewriter &rewriter) const override {
1178     rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
1179     return success();
1180   }
1181 };
1182 } // namespace
1183 
1184 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1185   return std::make_unique<ConvertAsyncToLLVMPass>();
1186 }
1187 
1188 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1189     MLIRContext *context, TypeConverter &typeConverter,
1190     OwningRewritePatternList &patterns, ConversionTarget &target) {
1191   typeConverter.addConversion([&](TokenType type) { return type; });
1192   typeConverter.addConversion([&](ValueType type) {
1193     return ValueType::get(typeConverter.convertType(type.getValueType()));
1194   });
1195 
1196   patterns
1197       .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1198           typeConverter, context);
1199 
1200   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1201       [&](Operation *op) { return typeConverter.isLegal(op); });
1202 }
1203