1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/BlockAndValueMapping.h"
18 #include "mlir/IR/ImplicitLocOpBuilder.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "convert-async-to-llvm"
27 
28 using namespace mlir;
29 using namespace mlir::async;
30 
// Prefix for functions outlined from `async.execute` op regions. It is the
// symbol name requested for every outlined function created by
// `outlineExecuteOp`.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
33 
34 //===----------------------------------------------------------------------===//
35 // Async Runtime C API declaration.
36 //===----------------------------------------------------------------------===//
37 
// Symbol names of the Async Runtime C API functions. Each of them is declared
// in the module by `addAsyncRuntimeApiDeclarations`.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Creation of async tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Emplacing tokens and values.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
// Awaiting tokens, values and groups.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Functions taking a coroutine handle and a resume function pointer (see the
// `*ExecuteFunctionType` signatures in `AsyncAPI`).
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
59 
60 namespace {
61 /// Async Runtime API function types.
62 ///
63 /// Because we can't create API function signature for type parametrized
64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
65 /// lowering all async data types become opaque pointers at runtime.
66 struct AsyncAPI {
67   // All async types are lowered to opaque i8* LLVM pointers at runtime.
68   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
69     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
70   }
71 
72   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
73     return LLVM::LLVMTokenType::get(ctx);
74   }
75 
76   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
77     auto ref = opaquePointerType(ctx);
78     auto count = IntegerType::get(ctx, 32);
79     return FunctionType::get(ctx, {ref, count}, {});
80   }
81 
82   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
83     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
84   }
85 
86   static FunctionType createValueFunctionType(MLIRContext *ctx) {
87     auto i32 = IntegerType::get(ctx, 32);
88     auto value = opaquePointerType(ctx);
89     return FunctionType::get(ctx, {i32}, {value});
90   }
91 
92   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
93     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
94   }
95 
96   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
97     auto value = opaquePointerType(ctx);
98     auto storage = opaquePointerType(ctx);
99     return FunctionType::get(ctx, {value}, {storage});
100   }
101 
102   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
103     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
104   }
105 
106   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
107     auto value = opaquePointerType(ctx);
108     return FunctionType::get(ctx, {value}, {});
109   }
110 
111   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
112     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
113   }
114 
115   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
116     auto value = opaquePointerType(ctx);
117     return FunctionType::get(ctx, {value}, {});
118   }
119 
120   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
121     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
122   }
123 
124   static FunctionType executeFunctionType(MLIRContext *ctx) {
125     auto hdl = opaquePointerType(ctx);
126     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
127     return FunctionType::get(ctx, {hdl, resume}, {});
128   }
129 
130   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
131     auto i64 = IntegerType::get(ctx, 64);
132     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
133                              {i64});
134   }
135 
136   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
137     auto hdl = opaquePointerType(ctx);
138     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
139     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
140   }
141 
142   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
143     auto value = opaquePointerType(ctx);
144     auto hdl = opaquePointerType(ctx);
145     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
146     return FunctionType::get(ctx, {value, hdl, resume}, {});
147   }
148 
149   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
150     auto hdl = opaquePointerType(ctx);
151     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
152     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
153   }
154 
155   // Auxiliary coroutine resume intrinsic wrapper.
156   static Type resumeFunctionType(MLIRContext *ctx) {
157     auto voidTy = LLVM::LLVMVoidType::get(ctx);
158     auto i8Ptr = opaquePointerType(ctx);
159     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
160   }
161 };
162 } // namespace
163 
164 /// Adds Async Runtime C API declarations to the module.
165 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
166   auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
167                                                          module.getBody());
168 
169   auto addFuncDecl = [&](StringRef name, FunctionType type) {
170     if (module.lookupSymbol(name))
171       return;
172     builder.create<FuncOp>(name, type).setPrivate();
173   };
174 
175   MLIRContext *ctx = module.getContext();
176   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
177   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
178   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
179   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
180   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
181   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
182   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
183   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
184   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
185   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
186   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
187   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
188   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
189   addFuncDecl(kAwaitTokenAndExecute,
190               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
191   addFuncDecl(kAwaitValueAndExecute,
192               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
193   addFuncDecl(kAwaitAllAndExecute,
194               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // LLVM coroutines intrinsics declarations.
199 //===----------------------------------------------------------------------===//
200 
// Symbol names of the LLVM coroutine intrinsics referenced by the lowering.
// All of them are declared in the module by
// `addCoroutineIntrinsicsDeclarations`.
static constexpr const char *kCoroId = "llvm.coro.id";
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr const char *kCoroBegin = "llvm.coro.begin";
static constexpr const char *kCoroSave = "llvm.coro.save";
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";
static constexpr const char *kCoroResume = "llvm.coro.resume";
209 
210 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
211                             StringRef name, Type ret, ArrayRef<Type> params) {
212   if (module.lookupSymbol(name))
213     return;
214   Type type = LLVM::LLVMFunctionType::get(ret, params);
215   builder.create<LLVM::LLVMFuncOp>(name, type);
216 }
217 
218 /// Adds coroutine intrinsics declarations to the module.
219 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
220   using namespace mlir::LLVM;
221 
222   MLIRContext *ctx = module.getContext();
223   ImplicitLocOpBuilder builder(module.getLoc(),
224                                module.getBody()->getTerminator());
225 
226   auto token = LLVMTokenType::get(ctx);
227   auto voidTy = LLVMVoidType::get(ctx);
228 
229   auto i8 = IntegerType::get(ctx, 8);
230   auto i1 = IntegerType::get(ctx, 1);
231   auto i32 = IntegerType::get(ctx, 32);
232   auto i64 = IntegerType::get(ctx, 64);
233   auto i8Ptr = LLVMPointerType::get(i8);
234 
235   addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
236   addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
237   addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
238   addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
239   addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
240   addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
241   addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
242   addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // Add malloc/free declarations to the module.
247 //===----------------------------------------------------------------------===//
248 
// Symbol names of the C runtime allocation functions. The lowering calls
// `malloc` to allocate the coroutine frame (async.coro.begin) and `free` to
// release it (async.coro.free).
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
251 
252 /// Adds malloc/free declarations to the module.
253 static void addCRuntimeDeclarations(ModuleOp module) {
254   using namespace mlir::LLVM;
255 
256   MLIRContext *ctx = module.getContext();
257   ImplicitLocOpBuilder builder(module.getLoc(),
258                                module.getBody()->getTerminator());
259 
260   auto voidTy = LLVMVoidType::get(ctx);
261   auto i64 = IntegerType::get(ctx, 64);
262   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
263 
264   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
265   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // Coroutine resume function wrapper.
270 //===----------------------------------------------------------------------===//
271 
// Symbol name of the resume wrapper function defined by `addResumeFunction`.
static constexpr const char *kResume = "__resume";
273 
274 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
275 /// intrinsics. We need this function to be able to pass it to the async
276 /// runtime execute API.
277 static void addResumeFunction(ModuleOp module) {
278   MLIRContext *ctx = module.getContext();
279 
280   OpBuilder moduleBuilder(module.getBody()->getTerminator());
281   Location loc = module.getLoc();
282 
283   if (module.lookupSymbol(kResume))
284     return;
285 
286   auto voidTy = LLVM::LLVMVoidType::get(ctx);
287   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
288 
289   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
290       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
291   resumeOp.setPrivate();
292 
293   auto *block = resumeOp.addEntryBlock();
294   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
295 
296   blockBuilder.create<LLVM::CallOp>(TypeRange(),
297                                     blockBuilder.getSymbolRefAttr(kCoroResume),
298                                     resumeOp.getArgument(0));
299 
300   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // async.execute op outlining to the coroutine functions.
305 //===----------------------------------------------------------------------===//
306 
307 /// Function targeted for coroutine transformation has two additional blocks at
308 /// the end: coroutine cleanup and coroutine suspension.
309 ///
/// async.await op lowering additionally creates a resume block for each
/// operation to enable non-blocking waiting via coroutine suspension.
312 namespace {
313 struct CoroMachinery {
314   // Async execute region returns a completion token, and an async value for
315   // each yielded value.
316   //
317   //   %token, %result = async.execute -> !async.value<T> {
318   //     %0 = constant ... : T
319   //     async.yield %0 : T
320   //   }
321   Value asyncToken; // token representing completion of the async region
322   llvm::SmallVector<Value, 4> returnValues; // returned async values
323 
324   Value coroHandle; // coroutine handle (!async.coro.handle value)
325   Block *cleanup;   // coroutine cleanup block
326   Block *suspend;   // coroutine suspension block
327 };
328 } // namespace
329 
330 /// Builds an coroutine template compatible with LLVM coroutines switched-resume
331 /// lowering using `async.runtime.*` and `async.coro.*` operations.
332 ///
333 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
334 ///
335 ///  - `entry` block sets up the coroutine.
336 ///  - `cleanup` block cleans up the coroutine state.
337 ///  - `suspend block after the @llvm.coro.end() defines what value will be
338 ///    returned to the initial caller of a coroutine. Everything before the
339 ///    @llvm.coro.end() will be executed at every suspension point.
340 ///
341 /// Coroutine structure (only the important bits):
342 ///
343 ///   func @async_execute_fn(<function-arguments>)
344 ///        -> (!async.token, !async.value<T>)
345 ///   {
346 ///     ^entry(<function-arguments>):
347 ///       %token = <async token> : !async.token    // create async runtime token
348 ///       %value = <async value> : !async.value<T> // create async value
349 ///       %id = async.coro.id                      // create a coroutine id
350 ///       %hdl = async.coro.begin %id              // create a coroutine handle
351 ///       br ^cleanup
352 ///
353 ///     ^cleanup:
354 ///       async.coro.free %hdl // delete the coroutine state
355 ///       br ^suspend
356 ///
357 ///     ^suspend:
358 ///       async.coro.end %hdl // marks the end of a coroutine
359 ///       return %token, %value : !async.token, !async.value<T>
360 ///   }
361 ///
362 /// The actual code for the async.execute operation body region will be inserted
363 /// before the entry block terminator.
364 ///
365 ///
366 static CoroMachinery setupCoroMachinery(FuncOp func) {
367   assert(func.getBody().empty() && "Function must have empty body");
368 
369   MLIRContext *ctx = func.getContext();
370   Block *entryBlock = func.addEntryBlock();
371 
372   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
373 
374   // ------------------------------------------------------------------------ //
375   // Allocate async token/values that we will return from a ramp function.
376   // ------------------------------------------------------------------------ //
377   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
378 
379   llvm::SmallVector<Value, 4> retValues;
380   for (auto resType : func.getCallableResults().drop_front())
381     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
382 
383   // ------------------------------------------------------------------------ //
384   // Initialize coroutine: get coroutine id and coroutine handle.
385   // ------------------------------------------------------------------------ //
386   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
387   auto coroHdlOp =
388       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
389 
390   Block *cleanupBlock = func.addBlock();
391   Block *suspendBlock = func.addBlock();
392 
393   // ------------------------------------------------------------------------ //
394   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
395   // ------------------------------------------------------------------------ //
396   builder.setInsertionPointToStart(cleanupBlock);
397   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
398 
399   // Branch into the suspend block.
400   builder.create<BranchOp>(suspendBlock);
401 
402   // ------------------------------------------------------------------------ //
403   // Coroutine suspend block: mark the end of a coroutine and return allocated
404   // async token.
405   // ------------------------------------------------------------------------ //
406   builder.setInsertionPointToStart(suspendBlock);
407 
408   // Mark the end of a coroutine: async.coro.end
409   builder.create<CoroEndOp>(coroHdlOp.handle());
410 
411   // Return created `async.token` and `async.values` from the suspend block.
412   // This will be the return value of a coroutine ramp function.
413   SmallVector<Value, 4> ret{retToken};
414   ret.insert(ret.end(), retValues.begin(), retValues.end());
415   builder.create<ReturnOp>(ret);
416 
417   // Branch from the entry block to the cleanup block to create a valid CFG.
418   builder.setInsertionPointToEnd(entryBlock);
419   builder.create<BranchOp>(cleanupBlock);
420 
421   // `async.await` op lowering will create resume blocks for async
422   // continuations, and will conditionally branch to cleanup or suspend blocks.
423 
424   CoroMachinery machinery;
425   machinery.asyncToken = retToken;
426   machinery.returnValues = retValues;
427   machinery.coroHandle = coroHdlOp.handle();
428   machinery.cleanup = cleanupBlock;
429   machinery.suspend = suspendBlock;
430   return machinery;
431 }
432 
433 /// Outline the body region attached to the `async.execute` op into a standalone
434 /// function.
435 ///
436 /// Note that this is not reversible transformation.
437 static std::pair<FuncOp, CoroMachinery>
438 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
439   ModuleOp module = execute->getParentOfType<ModuleOp>();
440 
441   MLIRContext *ctx = module.getContext();
442   Location loc = execute.getLoc();
443 
444   // Collect all outlined function inputs.
445   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
446                                               execute.dependencies().end());
447   functionInputs.insert(execute.operands().begin(), execute.operands().end());
448   getUsedValuesDefinedAbove(execute.body(), functionInputs);
449 
450   // Collect types for the outlined function inputs and outputs.
451   auto typesRange = llvm::map_range(
452       functionInputs, [](Value value) { return value.getType(); });
453   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
454   auto outputTypes = execute.getResultTypes();
455 
456   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
457   auto funcAttrs = ArrayRef<NamedAttribute>();
458 
459   // TODO: Derive outlined function name from the parent FuncOp (support
460   // multiple nested async.execute operations).
461   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
462   symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
463 
464   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
465 
466   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
467   // blocks, adding async.coro operations and setting up control flow.
468   CoroMachinery coro = setupCoroMachinery(func);
469 
470   // Suspend async function at the end of an entry block, and resume it using
471   // Async resume operation (execution will be resumed in a thread managed by
472   // the async runtime).
473   Block *entryBlock = &func.getBlocks().front();
474   auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
475 
476   // Save the coroutine state: async.coro.save
477   auto coroSaveOp =
478       builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
479 
480   // Pass coroutine to the runtime to be resumed on a runtime managed thread.
481   builder.create<RuntimeResumeOp>(coro.coroHandle);
482 
483   // Split the entry block before the terminator (branch to suspend block).
484   auto *terminatorOp = entryBlock->getTerminator();
485   Block *suspended = terminatorOp->getBlock();
486   Block *resume = suspended->splitBlock(terminatorOp);
487 
488   // Add async.coro.suspend as a suspended block terminator.
489   builder.setInsertionPointToEnd(suspended);
490   builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
491                                 coro.cleanup);
492 
493   size_t numDependencies = execute.dependencies().size();
494   size_t numOperands = execute.operands().size();
495 
496   // Await on all dependencies before starting to execute the body region.
497   builder.setInsertionPointToStart(resume);
498   for (size_t i = 0; i < numDependencies; ++i)
499     builder.create<AwaitOp>(func.getArgument(i));
500 
501   // Await on all async value operands and unwrap the payload.
502   SmallVector<Value, 4> unwrappedOperands(numOperands);
503   for (size_t i = 0; i < numOperands; ++i) {
504     Value operand = func.getArgument(numDependencies + i);
505     unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
506   }
507 
508   // Map from function inputs defined above the execute op to the function
509   // arguments.
510   BlockAndValueMapping valueMapping;
511   valueMapping.map(functionInputs, func.getArguments());
512   valueMapping.map(execute.body().getArguments(), unwrappedOperands);
513 
514   // Clone all operations from the execute operation body into the outlined
515   // function body.
516   for (Operation &op : execute.body().getOps())
517     builder.clone(op, valueMapping);
518 
519   // Replace the original `async.execute` with a call to outlined function.
520   ImplicitLocOpBuilder callBuilder(loc, execute);
521   auto callOutlinedFunc = callBuilder.create<CallOp>(
522       func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
523   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
524   execute.erase();
525 
526   return {func, coro};
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // Convert Async dialect types to LLVM types.
531 //===----------------------------------------------------------------------===//
532 
533 namespace {
534 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
535 /// their runtime type (opaque pointers) and does not convert any other types.
536 class AsyncRuntimeTypeConverter : public TypeConverter {
537 public:
538   AsyncRuntimeTypeConverter() {
539     addConversion([](Type type) { return type; });
540     addConversion(convertAsyncTypes);
541   }
542 
543   static Optional<Type> convertAsyncTypes(Type type) {
544     if (type.isa<TokenType, GroupType, ValueType>())
545       return AsyncAPI::opaquePointerType(type.getContext());
546 
547     if (type.isa<CoroIdType, CoroStateType>())
548       return AsyncAPI::tokenType(type.getContext());
549     if (type.isa<CoroHandleType>())
550       return AsyncAPI::opaquePointerType(type.getContext());
551 
552     return llvm::None;
553   }
554 };
555 } // namespace
556 
557 //===----------------------------------------------------------------------===//
558 // Convert async.coro.id to @llvm.coro.id intrinsic.
559 //===----------------------------------------------------------------------===//
560 
561 namespace {
562 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
563 public:
564   using OpConversionPattern::OpConversionPattern;
565 
566   LogicalResult
567   matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
568                   ConversionPatternRewriter &rewriter) const override {
569     auto token = AsyncAPI::tokenType(op->getContext());
570     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
571     auto loc = op->getLoc();
572 
573     // Constants for initializing coroutine frame.
574     auto constZero = rewriter.create<LLVM::ConstantOp>(
575         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
576     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
577 
578     // Get coroutine id: @llvm.coro.id.
579     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
580         op, token, rewriter.getSymbolRefAttr(kCoroId),
581         ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
582 
583     return success();
584   }
585 };
586 } // namespace
587 
588 //===----------------------------------------------------------------------===//
589 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
590 //===----------------------------------------------------------------------===//
591 
592 namespace {
593 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
594 public:
595   using OpConversionPattern::OpConversionPattern;
596 
597   LogicalResult
598   matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
599                   ConversionPatternRewriter &rewriter) const override {
600     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
601     auto loc = op->getLoc();
602 
603     // Get coroutine frame size: @llvm.coro.size.i64.
604     auto coroSize = rewriter.create<LLVM::CallOp>(
605         loc, rewriter.getI64Type(), rewriter.getSymbolRefAttr(kCoroSizeI64),
606         ValueRange());
607 
608     // Allocate memory for the coroutine frame.
609     auto coroAlloc = rewriter.create<LLVM::CallOp>(
610         loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
611         ValueRange(coroSize.getResult(0)));
612 
613     // Begin a coroutine: @llvm.coro.begin.
614     auto coroId = CoroBeginOpAdaptor(operands).id();
615     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
616         op, i8Ptr, rewriter.getSymbolRefAttr(kCoroBegin),
617         ValueRange({coroId, coroAlloc.getResult(0)}));
618 
619     return success();
620   }
621 };
622 } // namespace
623 
624 //===----------------------------------------------------------------------===//
625 // Convert async.coro.free to @llvm.coro.free intrinsic.
626 //===----------------------------------------------------------------------===//
627 
628 namespace {
629 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
630 public:
631   using OpConversionPattern::OpConversionPattern;
632 
633   LogicalResult
634   matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
635                   ConversionPatternRewriter &rewriter) const override {
636     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
637     auto loc = op->getLoc();
638 
639     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
640     auto coroMem = rewriter.create<LLVM::CallOp>(
641         loc, i8Ptr, rewriter.getSymbolRefAttr(kCoroFree), operands);
642 
643     // Free the memory.
644     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
645                                               rewriter.getSymbolRefAttr(kFree),
646                                               ValueRange(coroMem.getResult(0)));
647 
648     return success();
649   }
650 };
651 } // namespace
652 
653 //===----------------------------------------------------------------------===//
654 // Convert async.coro.end to @llvm.coro.end intrinsic.
655 //===----------------------------------------------------------------------===//
656 
657 namespace {
658 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
659 public:
660   using OpConversionPattern::OpConversionPattern;
661 
662   LogicalResult
663   matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
664                   ConversionPatternRewriter &rewriter) const override {
665     // We are not in the block that is part of the unwind sequence.
666     auto constFalse = rewriter.create<LLVM::ConstantOp>(
667         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
668 
669     // Mark the end of a coroutine: @llvm.coro.end.
670     auto coroHdl = CoroEndOpAdaptor(operands).handle();
671     rewriter.create<LLVM::CallOp>(op->getLoc(), rewriter.getI1Type(),
672                                   rewriter.getSymbolRefAttr(kCoroEnd),
673                                   ValueRange({coroHdl, constFalse}));
674     rewriter.eraseOp(op);
675 
676     return success();
677   }
678 };
679 } // namespace
680 
681 //===----------------------------------------------------------------------===//
682 // Convert async.coro.save to @llvm.coro.save intrinsic.
683 //===----------------------------------------------------------------------===//
684 
685 namespace {
686 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
687 public:
688   using OpConversionPattern::OpConversionPattern;
689 
690   LogicalResult
691   matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
692                   ConversionPatternRewriter &rewriter) const override {
693     // Save the coroutine state: @llvm.coro.save
694     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
695         op, AsyncAPI::tokenType(op->getContext()),
696         rewriter.getSymbolRefAttr(kCoroSave), operands);
697 
698     return success();
699   }
700 };
701 } // namespace
702 
703 //===----------------------------------------------------------------------===//
704 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
705 //===----------------------------------------------------------------------===//
706 
707 namespace {
708 
709 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
710 /// branch to the appropriate block based on the return code.
711 ///
712 /// Before:
713 ///
714 ///   ^suspended:
715 ///     "opBefore"(...)
716 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
717 ///   ^resume:
718 ///     "op"(...)
719 ///   ^cleanup: ...
720 ///   ^suspend: ...
721 ///
722 /// After:
723 ///
724 ///   ^suspended:
725 ///     "opBefore"(...)
726 ///     %suspend = llmv.call @llvm.coro.suspend(...)
727 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
728 ///   ^resume:
729 ///     "op"(...)
730 ///   ^cleanup: ...
731 ///   ^suspend: ...
732 ///
733 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
734 public:
735   using OpConversionPattern::OpConversionPattern;
736 
737   LogicalResult
738   matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
739                   ConversionPatternRewriter &rewriter) const override {
740     auto i8 = rewriter.getIntegerType(8);
741     auto i32 = rewriter.getI32Type();
742     auto loc = op->getLoc();
743 
744     // This is not a final suspension point.
745     auto constFalse = rewriter.create<LLVM::ConstantOp>(
746         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
747 
748     // Suspend a coroutine: @llvm.coro.suspend
749     auto coroState = CoroSuspendOpAdaptor(operands).state();
750     auto coroSuspend = rewriter.create<LLVM::CallOp>(
751         loc, i8, rewriter.getSymbolRefAttr(kCoroSuspend),
752         ValueRange({coroState, constFalse}));
753 
754     // Cast return code to i32.
755 
756     // After a suspension point decide if we should branch into resume, cleanup
757     // or suspend block of the coroutine (see @llvm.coro.suspend return code
758     // documentation).
759     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
760     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
761                                               op.cleanupDest()};
762     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
763         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult(0)),
764         /*defaultDestination=*/op.suspendDest(),
765         /*defaultOperands=*/ValueRange(),
766         /*caseValues=*/caseValues,
767         /*caseDestinations=*/caseDest,
768         /*caseOperands=*/ArrayRef<ValueRange>(),
769         /*branchWeights=*/ArrayRef<int32_t>());
770 
771     return success();
772   }
773 };
774 } // namespace
775 
776 //===----------------------------------------------------------------------===//
777 // Convert async.runtime.create to the corresponding runtime API call.
778 //
779 // To allocate storage for the async values we use getelementptr trick:
780 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
781 //===----------------------------------------------------------------------===//
782 
783 namespace {
784 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
785 public:
786   using OpConversionPattern::OpConversionPattern;
787 
788   LogicalResult
789   matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
790                   ConversionPatternRewriter &rewriter) const override {
791     TypeConverter *converter = getTypeConverter();
792     Type resultType = op->getResultTypes()[0];
793 
794     // Tokens and Groups lowered to function calls without arguments.
795     if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
796       rewriter.replaceOpWithNewOp<CallOp>(
797           op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
798           converter->convertType(resultType));
799       return success();
800     }
801 
802     // To create a value we need to compute the storage requirement.
803     if (auto value = resultType.dyn_cast<ValueType>()) {
804       // Returns the size requirements for the async value storage.
805       auto sizeOf = [&](ValueType valueType) -> Value {
806         auto loc = op->getLoc();
807         auto i32 = rewriter.getI32Type();
808 
809         auto storedType = converter->convertType(valueType.getValueType());
810         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
811 
812         // %Size = getelementptr %T* null, int 1
813         // %SizeI = ptrtoint %T* %Size to i32
814         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
815         auto one = rewriter.create<LLVM::ConstantOp>(
816             loc, i32, rewriter.getI32IntegerAttr(1));
817         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
818                                                 one.getResult());
819         return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep);
820       };
821 
822       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
823                                           sizeOf(value));
824 
825       return success();
826     }
827 
828     return rewriter.notifyMatchFailure(op, "unsupported async type");
829   }
830 };
831 } // namespace
832 
833 //===----------------------------------------------------------------------===//
834 // Convert async.runtime.set_available to the corresponding runtime API call.
835 //===----------------------------------------------------------------------===//
836 
837 namespace {
838 class RuntimeSetAvailableOpLowering
839     : public OpConversionPattern<RuntimeSetAvailableOp> {
840 public:
841   using OpConversionPattern::OpConversionPattern;
842 
843   LogicalResult
844   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
845                   ConversionPatternRewriter &rewriter) const override {
846     Type operandType = op.operand().getType();
847 
848     if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) {
849       rewriter.create<CallOp>(op->getLoc(),
850                               operandType.isa<TokenType>() ? kEmplaceToken
851                                                            : kEmplaceValue,
852                               TypeRange(), operands);
853       rewriter.eraseOp(op);
854       return success();
855     }
856 
857     return rewriter.notifyMatchFailure(op, "unsupported async type");
858   }
859 };
860 } // namespace
861 
862 //===----------------------------------------------------------------------===//
863 // Convert async.runtime.await to the corresponding runtime API call.
864 //===----------------------------------------------------------------------===//
865 
866 namespace {
867 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
868 public:
869   using OpConversionPattern::OpConversionPattern;
870 
871   LogicalResult
872   matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
873                   ConversionPatternRewriter &rewriter) const override {
874     Type operandType = op.operand().getType();
875 
876     StringRef apiFuncName;
877     if (operandType.isa<TokenType>())
878       apiFuncName = kAwaitToken;
879     else if (operandType.isa<ValueType>())
880       apiFuncName = kAwaitValue;
881     else if (operandType.isa<GroupType>())
882       apiFuncName = kAwaitGroup;
883     else
884       return rewriter.notifyMatchFailure(op, "unsupported async type");
885 
886     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
887     rewriter.eraseOp(op);
888 
889     return success();
890   }
891 };
892 } // namespace
893 
894 //===----------------------------------------------------------------------===//
895 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
896 //===----------------------------------------------------------------------===//
897 
898 namespace {
899 class RuntimeAwaitAndResumeOpLowering
900     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
901 public:
902   using OpConversionPattern::OpConversionPattern;
903 
904   LogicalResult
905   matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
906                   ConversionPatternRewriter &rewriter) const override {
907     Type operandType = op.operand().getType();
908 
909     StringRef apiFuncName;
910     if (operandType.isa<TokenType>())
911       apiFuncName = kAwaitTokenAndExecute;
912     else if (operandType.isa<ValueType>())
913       apiFuncName = kAwaitValueAndExecute;
914     else if (operandType.isa<GroupType>())
915       apiFuncName = kAwaitAllAndExecute;
916     else
917       return rewriter.notifyMatchFailure(op, "unsupported async type");
918 
919     Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
920     Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
921 
922     // A pointer to coroutine resume intrinsic wrapper.
923     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
924     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
925         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
926 
927     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
928                             ValueRange({operand, handle, resumePtr.res()}));
929     rewriter.eraseOp(op);
930 
931     return success();
932   }
933 };
934 } // namespace
935 
936 //===----------------------------------------------------------------------===//
937 // Convert async.runtime.resume to the corresponding runtime API call.
938 //===----------------------------------------------------------------------===//
939 
940 namespace {
941 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
942 public:
943   using OpConversionPattern::OpConversionPattern;
944 
945   LogicalResult
946   matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
947                   ConversionPatternRewriter &rewriter) const override {
948     // A pointer to coroutine resume intrinsic wrapper.
949     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
950     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
951         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
952 
953     // Call async runtime API to execute a coroutine in the managed thread.
954     auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
955     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
956                                         ValueRange({coroHdl, resumePtr.res()}));
957 
958     return success();
959   }
960 };
961 } // namespace
962 
963 //===----------------------------------------------------------------------===//
964 // Convert async.runtime.store to the corresponding runtime API call.
965 //===----------------------------------------------------------------------===//
966 
967 namespace {
968 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
969 public:
970   using OpConversionPattern::OpConversionPattern;
971 
972   LogicalResult
973   matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
974                   ConversionPatternRewriter &rewriter) const override {
975     Location loc = op->getLoc();
976 
977     // Get a pointer to the async value storage from the runtime.
978     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
979     auto storage = RuntimeStoreOpAdaptor(operands).storage();
980     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
981                                               TypeRange(i8Ptr), storage);
982 
983     // Cast from i8* to the LLVM pointer type.
984     auto valueType = op.value().getType();
985     auto llvmValueType = getTypeConverter()->convertType(valueType);
986     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
987         loc, LLVM::LLVMPointerType::get(llvmValueType),
988         storagePtr.getResult(0));
989 
990     // Store the yielded value into the async value storage.
991     auto value = RuntimeStoreOpAdaptor(operands).value();
992     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
993 
994     // Erase the original runtime store operation.
995     rewriter.eraseOp(op);
996 
997     return success();
998   }
999 };
1000 } // namespace
1001 
1002 //===----------------------------------------------------------------------===//
1003 // Convert async.runtime.load to the corresponding runtime API call.
1004 //===----------------------------------------------------------------------===//
1005 
1006 namespace {
1007 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
1008 public:
1009   using OpConversionPattern::OpConversionPattern;
1010 
1011   LogicalResult
1012   matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
1013                   ConversionPatternRewriter &rewriter) const override {
1014     Location loc = op->getLoc();
1015 
1016     // Get a pointer to the async value storage from the runtime.
1017     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
1018     auto storage = RuntimeLoadOpAdaptor(operands).storage();
1019     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
1020                                               TypeRange(i8Ptr), storage);
1021 
1022     // Cast from i8* to the LLVM pointer type.
1023     auto valueType = op.result().getType();
1024     auto llvmValueType = getTypeConverter()->convertType(valueType);
1025     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
1026         loc, LLVM::LLVMPointerType::get(llvmValueType),
1027         storagePtr.getResult(0));
1028 
1029     // Load from the casted pointer.
1030     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
1031 
1032     return success();
1033   }
1034 };
1035 } // namespace
1036 
1037 //===----------------------------------------------------------------------===//
1038 // Convert async.runtime.add_to_group to the corresponding runtime API call.
1039 //===----------------------------------------------------------------------===//
1040 
1041 namespace {
1042 class RuntimeAddToGroupOpLowering
1043     : public OpConversionPattern<RuntimeAddToGroupOp> {
1044 public:
1045   using OpConversionPattern::OpConversionPattern;
1046 
1047   LogicalResult
1048   matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
1049                   ConversionPatternRewriter &rewriter) const override {
1050     // Currently we can only add tokens to the group.
1051     if (!op.operand().getType().isa<TokenType>())
1052       return rewriter.notifyMatchFailure(op, "only token type is supported");
1053 
1054     // Replace with a runtime API function call.
1055     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
1056                                         rewriter.getI64Type(), operands);
1057 
1058     return success();
1059   }
1060 };
1061 } // namespace
1062 
1063 //===----------------------------------------------------------------------===//
1064 // Async reference counting ops lowering (`async.runtime.add_ref` and
1065 // `async.runtime.drop_ref` to the corresponding API calls).
1066 //===----------------------------------------------------------------------===//
1067 
1068 namespace {
1069 template <typename RefCountingOp>
1070 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
1071 public:
1072   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
1073                                  StringRef apiFunctionName)
1074       : OpConversionPattern<RefCountingOp>(converter, ctx),
1075         apiFunctionName(apiFunctionName) {}
1076 
1077   LogicalResult
1078   matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
1079                   ConversionPatternRewriter &rewriter) const override {
1080     auto count =
1081         rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
1082                                     rewriter.getI32IntegerAttr(op.count()));
1083 
1084     auto operand = typename RefCountingOp::Adaptor(operands).operand();
1085     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
1086                                         ValueRange({operand, count}));
1087 
1088     return success();
1089   }
1090 
1091 private:
1092   StringRef apiFunctionName;
1093 };
1094 
1095 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
1096 public:
1097   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
1098       : RefCountingOpLowering(converter, ctx, kAddRef) {}
1099 };
1100 
1101 class RuntimeDropRefOpLowering
1102     : public RefCountingOpLowering<RuntimeDropRefOp> {
1103 public:
1104   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
1105       : RefCountingOpLowering(converter, ctx, kDropRef) {}
1106 };
1107 } // namespace
1108 
1109 //===----------------------------------------------------------------------===//
1110 // Convert return operations that return async values from async regions.
1111 //===----------------------------------------------------------------------===//
1112 
1113 namespace {
1114 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
1115 public:
1116   using OpConversionPattern::OpConversionPattern;
1117 
1118   LogicalResult
1119   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
1120                   ConversionPatternRewriter &rewriter) const override {
1121     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
1122     return success();
1123   }
1124 };
1125 } // namespace
1126 
1127 //===----------------------------------------------------------------------===//
1128 // Convert async.create_group operation to async.runtime.create
1129 //===----------------------------------------------------------------------===//
1130 
1131 namespace {
1132 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
1133 public:
1134   using OpConversionPattern::OpConversionPattern;
1135 
1136   LogicalResult
1137   matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
1138                   ConversionPatternRewriter &rewriter) const override {
1139     rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
1140         op, GroupType::get(op->getContext()));
1141     return success();
1142   }
1143 };
1144 } // namespace
1145 
1146 //===----------------------------------------------------------------------===//
1147 // Convert async.add_to_group operation to async.runtime.add_to_group.
1148 //===----------------------------------------------------------------------===//
1149 
1150 namespace {
1151 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
1152 public:
1153   using OpConversionPattern::OpConversionPattern;
1154 
1155   LogicalResult
1156   matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
1157                   ConversionPatternRewriter &rewriter) const override {
1158     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
1159         op, rewriter.getIndexType(), operands);
1160     return success();
1161   }
1162 };
1163 } // namespace
1164 
1165 //===----------------------------------------------------------------------===//
1166 // Convert async.await and async.await_all operations to the async.runtime.await
1167 // or async.runtime.await_and_resume operations.
1168 //===----------------------------------------------------------------------===//
1169 
1170 namespace {
1171 template <typename AwaitType, typename AwaitableType>
1172 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
1173   using AwaitAdaptor = typename AwaitType::Adaptor;
1174 
1175 public:
1176   AwaitOpLoweringBase(
1177       MLIRContext *ctx,
1178       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
1179       : OpConversionPattern<AwaitType>(ctx),
1180         outlinedFunctions(outlinedFunctions) {}
1181 
1182   LogicalResult
1183   matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
1184                   ConversionPatternRewriter &rewriter) const override {
1185     // We can only await on one the `AwaitableType` (for `await` it can be
1186     // a `token` or a `value`, for `await_all` it must be a `group`).
1187     if (!op.operand().getType().template isa<AwaitableType>())
1188       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
1189 
1190     // Check if await operation is inside the outlined coroutine function.
1191     auto func = op->template getParentOfType<FuncOp>();
1192     auto outlined = outlinedFunctions.find(func);
1193     const bool isInCoroutine = outlined != outlinedFunctions.end();
1194 
1195     Location loc = op->getLoc();
1196     Value operand = AwaitAdaptor(operands).operand();
1197 
1198     // Inside regular functions we use the blocking wait operation to wait for
1199     // the async object (token, value or group) to become available.
1200     if (!isInCoroutine)
1201       rewriter.create<RuntimeAwaitOp>(loc, operand);
1202 
1203     // Inside the coroutine we convert await operation into coroutine suspension
1204     // point, and resume execution asynchronously.
1205     if (isInCoroutine) {
1206       const CoroMachinery &coro = outlined->getSecond();
1207       Block *suspended = op->getBlock();
1208 
1209       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
1210       MLIRContext *ctx = op->getContext();
1211 
1212       // Save the coroutine state and resume on a runtime managed thread when
1213       // the operand becomes available.
1214       auto coroSaveOp =
1215           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
1216       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
1217 
1218       // Split the entry block before the await operation.
1219       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
1220 
1221       // Add async.coro.suspend as a suspended block terminator.
1222       builder.setInsertionPointToEnd(suspended);
1223       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
1224                                     coro.cleanup);
1225 
1226       // Make sure that replacement value will be constructed in resume block.
1227       rewriter.setInsertionPointToStart(resume);
1228     }
1229 
1230     // Erase or replace the await operation with the new value.
1231     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
1232       rewriter.replaceOp(op, replaceWith);
1233     else
1234       rewriter.eraseOp(op);
1235 
1236     return success();
1237   }
1238 
1239   virtual Value getReplacementValue(AwaitType op, Value operand,
1240                                     ConversionPatternRewriter &rewriter) const {
1241     return Value();
1242   }
1243 
1244 private:
1245   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
1246 };
1247 
1248 /// Lowering for `async.await` with a token operand.
1249 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
1250   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
1251 
1252 public:
1253   using Base::Base;
1254 };
1255 
1256 /// Lowering for `async.await` with a value operand.
1257 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
1258   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
1259 
1260 public:
1261   using Base::Base;
1262 
1263   Value
1264   getReplacementValue(AwaitOp op, Value operand,
1265                       ConversionPatternRewriter &rewriter) const override {
1266     // Load from the async value storage.
1267     auto valueType = operand.getType().cast<ValueType>().getValueType();
1268     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
1269   }
1270 };
1271 
1272 /// Lowering for `async.await_all` operation.
1273 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
1274   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
1275 
1276 public:
1277   using Base::Base;
1278 };
1279 
1280 } // namespace
1281 
1282 //===----------------------------------------------------------------------===//
1283 // Convert async.yield operation to async.runtime operations.
1284 //===----------------------------------------------------------------------===//
1285 
1286 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
1287 public:
1288   YieldOpLowering(
1289       MLIRContext *ctx,
1290       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
1291       : OpConversionPattern<async::YieldOp>(ctx),
1292         outlinedFunctions(outlinedFunctions) {}
1293 
1294   LogicalResult
1295   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
1296                   ConversionPatternRewriter &rewriter) const override {
1297     // Check if yield operation is inside the outlined coroutine function.
1298     auto func = op->template getParentOfType<FuncOp>();
1299     auto outlined = outlinedFunctions.find(func);
1300     if (outlined == outlinedFunctions.end())
1301       return rewriter.notifyMatchFailure(
1302           op, "operation is not inside the outlined async.execute function");
1303 
1304     Location loc = op->getLoc();
1305     const CoroMachinery &coro = outlined->getSecond();
1306 
1307     // Store yielded values into the async values storage and switch async
1308     // values state to available.
1309     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
1310       Value yieldValue = std::get<0>(tuple);
1311       Value asyncValue = std::get<1>(tuple);
1312       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
1313       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
1314     }
1315 
1316     // Switch the coroutine completion token to available state.
1317     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
1318 
1319     return success();
1320   }
1321 
1322 private:
1323   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
1324 };
1325 
1326 //===----------------------------------------------------------------------===//
1327 
1328 namespace {
1329 struct ConvertAsyncToLLVMPass
1330     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
1331   void runOnOperation() override;
1332 };
1333 } // namespace
1334 
/// Lowers the Async dialect operations in the module in three steps:
///   1. Outlines the body of every `async.execute` operation into a function
///      and records its coroutine machinery in `outlinedFunctions`.
///   2. Rewrites the high level async operations (create_group, add_to_group,
///      await, await_all, yield) into async.runtime operations.
///   3. Converts async.runtime and async.coro operations, together with async
///      types in function signatures, calls and returns, to async runtime API
///      calls and LLVM coroutine intrinsics.
/// The pass signals failure as soon as one of the steps fails.
void ConvertAsyncToLLVMPass::runOnOperation() {
  ModuleOp module = getOperation();
  SymbolTable symbolTable(module);

  MLIRContext *ctx = &getContext();

  // Outline all `async.execute` body regions into async functions (coroutines).
  llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;

  // We use conversion to LLVM type to ensure that all `async.value` operands
  // and results can be lowered to LLVM load and store operations.
  LLVMTypeConverter llvmConverter(ctx);
  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);

  // Returns true if the `async.value` payload is convertible to LLVM.
  // `type` must be an `async.value` type (it is accessed with `cast`).
  auto isConvertibleToLlvm = [&](Type type) -> bool {
    auto valueType = type.cast<ValueType>().getValueType();
    return static_cast<bool>(llvmConverter.convertType(valueType));
  };

  // Outline every `async.execute` operation, interrupting the walk on the
  // first operation that cannot be lowered.
  WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
    // All operands and results must be convertible to LLVM.
    if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
      execute.emitOpError("operands payload must be convertible to LLVM type");
      return WalkResult::interrupt();
    }
    if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
      execute.emitOpError("results payload must be convertible to LLVM type");
      return WalkResult::interrupt();
    }

    outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));

    return WalkResult::advance();
  });

  // Failed to outline all async execute operations.
  if (outlineResult.wasInterrupted()) {
    signalPassFailure();
    return;
  }

  LLVM_DEBUG({
    llvm::dbgs() << "Outlined " << outlinedFunctions.size()
                 << " async functions\n";
  });

  // Add declarations for all functions required by the coroutines lowering.
  addResumeFunction(module);
  addAsyncRuntimeApiDeclarations(module);
  addCoroutineIntrinsicsDeclarations(module);
  addCRuntimeDeclarations(module);

  // ------------------------------------------------------------------------ //
  // Lower async operations to async.runtime operations.
  // ------------------------------------------------------------------------ //
  OwningRewritePatternList asyncPatterns;

  // Async lowering does not use type converter because it must preserve all
  // types for async.runtime operations.
  asyncPatterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
  // The await and yield patterns consult `outlinedFunctions` to find out
  // whether the operation is inside an outlined coroutine function.
  asyncPatterns.insert<AwaitTokenOpLowering, AwaitValueOpLowering,
                       AwaitAllOpLowering, YieldOpLowering>(ctx,
                                                            outlinedFunctions);

  // All high level async operations must be lowered to the runtime operations.
  // The rest of the Async dialect stays legal in this first conversion.
  ConversionTarget runtimeTarget(*ctx);
  runtimeTarget.addLegalDialect<AsyncDialect>();
  runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();

  if (failed(applyPartialConversion(module, runtimeTarget,
                                    std::move(asyncPatterns)))) {
    signalPassFailure();
    return;
  }

  // ------------------------------------------------------------------------ //
  // Lower async.runtime and async.coro operations to Async Runtime API and
  // LLVM coroutine intrinsics.
  // ------------------------------------------------------------------------ //

  // Convert async dialect types and operations to LLVM dialect.
  AsyncRuntimeTypeConverter converter;
  OwningRewritePatternList patterns;

  // Convert async types in function signatures and function calls.
  populateFuncOpTypeConversionPattern(patterns, ctx, converter);
  populateCallOpTypeConversionPattern(patterns, ctx, converter);

  // Convert return operations inside async.execute regions.
  patterns.insert<ReturnOpOpConversion>(converter, ctx);

  // Lower async.runtime operations to the async runtime API calls.
  patterns.insert<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
                  RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
                  RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
                  RuntimeDropRefOpLowering>(converter, ctx);

  // Lower async.runtime operations that rely on LLVM type converter to convert
  // from async value payload type to the LLVM type.
  patterns.insert<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
                  RuntimeLoadOpLowering>(llvmConverter, ctx);

  // Lower async coroutine operations to LLVM coroutine intrinsics.
  patterns.insert<CoroIdOpConversion, CoroBeginOpConversion,
                  CoroFreeOpConversion, CoroEndOpConversion,
                  CoroSaveOpConversion, CoroSuspendOpConversion>(converter,
                                                                 ctx);

  ConversionTarget target(*ctx);
  target.addLegalOp<ConstantOp>();
  target.addLegalDialect<LLVM::LLVMDialect>();

  // All operations from Async dialect must be lowered to the runtime API and
  // LLVM intrinsics calls.
  target.addIllegalDialect<AsyncDialect>();

  // Add dynamic legality constraints to apply conversions defined above:
  // functions, returns and calls are legal only once their types are legal
  // for `converter`.
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addDynamicallyLegalOp<ReturnOp>(
      [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
    return converter.isSignatureLegal(op.getCalleeType());
  });

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
1465 
1466 //===----------------------------------------------------------------------===//
1467 // Patterns for structural type conversions for the Async dialect operations.
1468 //===----------------------------------------------------------------------===//
1469 
1470 namespace {
1471 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1472 public:
1473   using OpConversionPattern::OpConversionPattern;
1474   LogicalResult
1475   matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
1476                   ConversionPatternRewriter &rewriter) const override {
1477     ExecuteOp newOp =
1478         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1479     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1480                                 newOp.getRegion().end());
1481 
1482     // Set operands and update block argument and result types.
1483     newOp->setOperands(operands);
1484     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1485       return failure();
1486     for (auto result : newOp.getResults())
1487       result.setType(typeConverter->convertType(result.getType()));
1488 
1489     rewriter.replaceOp(op, newOp.getResults());
1490     return success();
1491   }
1492 };
1493 
1494 // Dummy pattern to trigger the appropriate type conversion / materialization.
1495 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1496 public:
1497   using OpConversionPattern::OpConversionPattern;
1498   LogicalResult
1499   matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
1500                   ConversionPatternRewriter &rewriter) const override {
1501     rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
1502     return success();
1503   }
1504 };
1505 
1506 // Dummy pattern to trigger the appropriate type conversion / materialization.
1507 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1508 public:
1509   using OpConversionPattern::OpConversionPattern;
1510   LogicalResult
1511   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
1512                   ConversionPatternRewriter &rewriter) const override {
1513     rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
1514     return success();
1515   }
1516 };
1517 } // namespace
1518 
1519 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1520   return std::make_unique<ConvertAsyncToLLVMPass>();
1521 }
1522 
1523 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1524     MLIRContext *context, TypeConverter &typeConverter,
1525     OwningRewritePatternList &patterns, ConversionTarget &target) {
1526   typeConverter.addConversion([&](TokenType type) { return type; });
1527   typeConverter.addConversion([&](ValueType type) {
1528     return ValueType::get(typeConverter.convertType(type.getValueType()));
1529   });
1530 
1531   patterns
1532       .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1533           typeConverter, context);
1534 
1535   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1536       [&](Operation *op) { return typeConverter.isLegal(op); });
1537 }
1538