1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Async/IR/Async.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 #define DEBUG_TYPE "convert-async-to-llvm"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
// Symbol name requested for every function outlined from an `async.execute`
// op region.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
31 
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35 
// Symbol names of the Async Runtime C API functions used by this lowering.

// Creation of tokens and groups, and adding a token to a group.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Emplacing the token returned from an outlined function.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";

// Awaits used outside of coroutines (lowered to plain calls).
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Functions that receive a coroutine handle together with a pointer to the
// resume function.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
48 
49 namespace {
50 // Async Runtime API function types.
51 struct AsyncAPI {
52   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
53     return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
54   }
55 
56   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
57     return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
58   }
59 
60   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
61     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
62   }
63 
64   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
65     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
66   }
67 
68   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
69     return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
70   }
71 
72   static FunctionType executeFunctionType(MLIRContext *ctx) {
73     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
74     auto resume = resumeFunctionType(ctx).getPointerTo();
75     return FunctionType::get({hdl, resume}, {}, ctx);
76   }
77 
78   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
79     auto i64 = IntegerType::get(64, ctx);
80     return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
81                              ctx);
82   }
83 
84   static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
85     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
86     auto resume = resumeFunctionType(ctx).getPointerTo();
87     return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
88   }
89 
90   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
91     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
92     auto resume = resumeFunctionType(ctx).getPointerTo();
93     return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
94   }
95 
96   // Auxiliary coroutine resume intrinsic wrapper.
97   static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
98     auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
99     auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
100     return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
101   }
102 };
103 } // namespace
104 
105 // Adds Async Runtime C API declarations to the module.
106 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
107   auto builder = OpBuilder::atBlockTerminator(module.getBody());
108 
109   auto addFuncDecl = [&](StringRef name, FunctionType type) {
110     if (module.lookupSymbol(name))
111       return;
112     builder.create<FuncOp>(module.getLoc(), name, type).setPrivate();
113   };
114 
115   MLIRContext *ctx = module.getContext();
116   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
117   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
118   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
119   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
120   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
121   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
122   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
123   addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
124   addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // LLVM coroutines intrinsics declarations.
129 //===----------------------------------------------------------------------===//
130 
// Names of the LLVM coroutine intrinsics called by the lowering. The
// signatures noted below are the ones this file declares for them.

// (i32, i8*, i8*, i8*) -> token
static constexpr const char *kCoroId = "llvm.coro.id";
// () -> i64
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
// (token, i8*) -> i8*
static constexpr const char *kCoroBegin = "llvm.coro.begin";
// (i8*) -> token
static constexpr const char *kCoroSave = "llvm.coro.save";
// (token, i1) -> i8
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
// (i8*, i1) -> i1
static constexpr const char *kCoroEnd = "llvm.coro.end";
// (token, i8*) -> i8*
static constexpr const char *kCoroFree = "llvm.coro.free";
// (i8*) -> void
static constexpr const char *kCoroResume = "llvm.coro.resume";
139 
140 /// Adds an LLVM function declaration to a module.
141 static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name,
142                             LLVM::LLVMType ret,
143                             ArrayRef<LLVM::LLVMType> params) {
144   if (module.lookupSymbol(name))
145     return;
146   LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false);
147   builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type);
148 }
149 
150 /// Adds coroutine intrinsics declarations to the module.
151 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
152   using namespace mlir::LLVM;
153 
154   MLIRContext *ctx = module.getContext();
155   OpBuilder builder(module.getBody()->getTerminator());
156 
157   auto token = LLVMTokenType::get(ctx);
158   auto voidTy = LLVMType::getVoidTy(ctx);
159 
160   auto i8 = LLVMType::getInt8Ty(ctx);
161   auto i1 = LLVMType::getInt1Ty(ctx);
162   auto i32 = LLVMType::getInt32Ty(ctx);
163   auto i64 = LLVMType::getInt64Ty(ctx);
164   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
165 
166   addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
167   addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
168   addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
169   addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
170   addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
171   addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
172   addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
173   addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Add malloc/free declarations to the module.
178 //===----------------------------------------------------------------------===//
179 
// C runtime functions used to allocate and release the coroutine frame.
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
182 
183 /// Adds malloc/free declarations to the module.
184 static void addCRuntimeDeclarations(ModuleOp module) {
185   using namespace mlir::LLVM;
186 
187   MLIRContext *ctx = module.getContext();
188   OpBuilder builder(module.getBody()->getTerminator());
189 
190   auto voidTy = LLVMType::getVoidTy(ctx);
191   auto i64 = LLVMType::getInt64Ty(ctx);
192   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
193 
194   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
195   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // Coroutine resume function wrapper.
200 //===----------------------------------------------------------------------===//
201 
// Name of the private wrapper function that calls `llvm.coro.resume`.
static constexpr const char *kResume = "__resume";
203 
204 // A function that takes a coroutine handle and calls a `llvm.coro.resume`
205 // intrinsics. We need this function to be able to pass it to the async
206 // runtime execute API.
207 static void addResumeFunction(ModuleOp module) {
208   MLIRContext *ctx = module.getContext();
209 
210   OpBuilder moduleBuilder(module.getBody()->getTerminator());
211   Location loc = module.getLoc();
212 
213   if (module.lookupSymbol(kResume))
214     return;
215 
216   auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
217   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
218 
219   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
220       loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false));
221   resumeOp.setPrivate();
222 
223   auto *block = resumeOp.addEntryBlock();
224   OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
225 
226   blockBuilder.create<LLVM::CallOp>(loc, Type(),
227                                     blockBuilder.getSymbolRefAttr(kCoroResume),
228                                     resumeOp.getArgument(0));
229 
230   blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // async.execute op outlining to the coroutine functions.
235 //===----------------------------------------------------------------------===//
236 
237 // Function targeted for coroutine transformation has two additional blocks at
238 // the end: coroutine cleanup and coroutine suspension.
239 //
240 // async.await op lowering additionaly creates a resume block for each
241 // operation to enable non-blocking waiting via coroutine suspension.
242 namespace {
243 struct CoroMachinery {
244   Value asyncToken;
245   Value coroHandle;
246   Block *cleanup;
247   Block *suspend;
248 };
249 } // namespace
250 
251 // Builds an coroutine template compatible with LLVM coroutines lowering.
252 //
253 //  - `entry` block sets up the coroutine.
254 //  - `cleanup` block cleans up the coroutine state.
255 //  - `suspend block after the @llvm.coro.end() defines what value will be
256 //    returned to the initial caller of a coroutine. Everything before the
257 //    @llvm.coro.end() will be executed at every suspension point.
258 //
259 // Coroutine structure (only the important bits):
260 //
261 //   func @async_execute_fn(<function-arguments>) -> !async.token {
262 //     ^entryBlock(<function-arguments>):
263 //       %token = <async token> : !async.token // create async runtime token
264 //       %hdl = llvm.call @llvm.coro.id(...)   // create a coroutine handle
265 //       br ^cleanup
266 //
267 //     ^cleanup:
268 //       llvm.call @llvm.coro.free(...)        // delete coroutine state
269 //       br ^suspend
270 //
271 //     ^suspend:
272 //       llvm.call @llvm.coro.end(...)         // marks the end of a coroutine
273 //       return %token : !async.token
274 //   }
275 //
276 // The actual code for the async.execute operation body region will be inserted
277 // before the entry block terminator.
278 //
279 //
280 static CoroMachinery setupCoroMachinery(FuncOp func) {
281   assert(func.getBody().empty() && "Function must have empty body");
282 
283   MLIRContext *ctx = func.getContext();
284 
285   auto token = LLVM::LLVMTokenType::get(ctx);
286   auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
287   auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
288   auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
289   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
290 
291   Block *entryBlock = func.addEntryBlock();
292   Location loc = func.getBody().getLoc();
293 
294   OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);
295 
296   // ------------------------------------------------------------------------ //
297   // Allocate async tokens/values that we will return from a ramp function.
298   // ------------------------------------------------------------------------ //
299   auto createToken =
300       builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));
301 
302   // ------------------------------------------------------------------------ //
303   // Initialize coroutine: allocate frame, get coroutine handle.
304   // ------------------------------------------------------------------------ //
305 
306   // Constants for initializing coroutine frame.
307   auto constZero =
308       builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
309   auto constFalse =
310       builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
311   auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);
312 
313   // Get coroutine id: @llvm.coro.id
314   auto coroId = builder.create<LLVM::CallOp>(
315       loc, token, builder.getSymbolRefAttr(kCoroId),
316       ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
317 
318   // Get coroutine frame size: @llvm.coro.size.i64
319   auto coroSize = builder.create<LLVM::CallOp>(
320       loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
321 
322   // Allocate memory for coroutine frame.
323   auto coroAlloc = builder.create<LLVM::CallOp>(
324       loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
325       ValueRange(coroSize.getResult(0)));
326 
327   // Begin a coroutine: @llvm.coro.begin
328   auto coroHdl = builder.create<LLVM::CallOp>(
329       loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
330       ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
331 
332   Block *cleanupBlock = func.addBlock();
333   Block *suspendBlock = func.addBlock();
334 
335   // ------------------------------------------------------------------------ //
336   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
337   // ------------------------------------------------------------------------ //
338   builder.setInsertionPointToStart(cleanupBlock);
339 
340   // Get a pointer to the coroutine frame memory: @llvm.coro.free.
341   auto coroMem = builder.create<LLVM::CallOp>(
342       loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
343       ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
344 
345   // Free the memory.
346   builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree),
347                                ValueRange(coroMem.getResult(0)));
348   // Branch into the suspend block.
349   builder.create<BranchOp>(loc, suspendBlock);
350 
351   // ------------------------------------------------------------------------ //
352   // Coroutine suspend block: mark the end of a coroutine and return allocated
353   // async token.
354   // ------------------------------------------------------------------------ //
355   builder.setInsertionPointToStart(suspendBlock);
356 
357   // Mark the end of a coroutine: @llvm.coro.end.
358   builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
359                                ValueRange({coroHdl.getResult(0), constFalse}));
360 
361   // Return created `async.token` from the suspend block. This will be the
362   // return value of a coroutine ramp function.
363   builder.create<ReturnOp>(loc, createToken.getResult(0));
364 
365   // Branch from the entry block to the cleanup block to create a valid CFG.
366   builder.setInsertionPointToEnd(entryBlock);
367 
368   builder.create<BranchOp>(loc, cleanupBlock);
369 
370   // `async.await` op lowering will create resume blocks for async
371   // continuations, and will conditionally branch to cleanup or suspend blocks.
372 
373   return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
374           suspendBlock};
375 }
376 
377 // Adds a suspension point before the `op`, and moves `op` and all operations
378 // after it into the resume block. Returns a pointer to the resume block.
379 //
380 // `coroState` must be a value returned from the call to @llvm.coro.save(...)
381 // intrinsic (saved coroutine state).
382 //
383 // Before:
384 //
385 //   ^bb0:
386 //     "opBefore"(...)
387 //     "op"(...)
388 //   ^cleanup: ...
389 //   ^suspend: ...
390 //
391 // After:
392 //
393 //   ^bb0:
394 //     "opBefore"(...)
395 //     %suspend = llmv.call @llvm.coro.suspend(...)
396 //     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
397 //   ^resume:
398 //     "op"(...)
399 //   ^cleanup: ...
400 //   ^suspend: ...
401 //
402 static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
403                                  Operation *op) {
404   MLIRContext *ctx = op->getContext();
405   auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
406   auto i8 = LLVM::LLVMType::getInt8Ty(ctx);
407 
408   Location loc = op->getLoc();
409   Block *splitBlock = op->getBlock();
410 
411   // Split the block before `op`, newly added block is the resume block.
412   Block *resume = splitBlock->splitBlock(op);
413 
414   // Add a coroutine suspension in place of original `op` in the split block.
415   OpBuilder builder = OpBuilder::atBlockEnd(splitBlock);
416 
417   auto constFalse =
418       builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
419 
420   // Suspend a coroutine: @llvm.coro.suspend
421   auto coroSuspend = builder.create<LLVM::CallOp>(
422       loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
423       ValueRange({coroState, constFalse}));
424 
425   // After a suspension point decide if we should branch into resume, cleanup
426   // or suspend block of the coroutine (see @llvm.coro.suspend return code
427   // documentation).
428   auto constZero =
429       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
430   auto constNegOne =
431       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
432 
433   Block *resumeOrCleanup = builder.createBlock(resume);
434 
435   // Suspend the coroutine ...?
436   builder.setInsertionPointToEnd(splitBlock);
437   auto isNegOne = builder.create<LLVM::ICmpOp>(
438       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
439   builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
440                                  /*falseDest=*/resumeOrCleanup);
441 
442   // ... or resume or cleanup the coroutine?
443   builder.setInsertionPointToStart(resumeOrCleanup);
444   auto isZero = builder.create<LLVM::ICmpOp>(
445       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
446   builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
447                                  /*falseDest=*/coro.cleanup);
448 
449   return resume;
450 }
451 
452 // Outline the body region attached to the `async.execute` op into a standalone
453 // function.
454 static std::pair<FuncOp, CoroMachinery>
455 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
456   ModuleOp module = execute.getParentOfType<ModuleOp>();
457 
458   MLIRContext *ctx = module.getContext();
459   Location loc = execute.getLoc();
460 
461   OpBuilder moduleBuilder(module.getBody()->getTerminator());
462 
463   // Collect all outlined function inputs.
464   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
465                                               execute.dependencies().end());
466   getUsedValuesDefinedAbove(execute.body(), functionInputs);
467 
468   // Collect types for the outlined function inputs and outputs.
469   auto typesRange = llvm::map_range(
470       functionInputs, [](Value value) { return value.getType(); });
471   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
472   auto outputTypes = execute.getResultTypes();
473 
474   auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
475   auto funcAttrs = ArrayRef<NamedAttribute>();
476 
477   // TODO: Derive outlined function name from the parent FuncOp (support
478   // multiple nested async.execute operations).
479   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
480   symbolTable.insert(func, moduleBuilder.getInsertionPoint());
481 
482   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
483 
484   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
485   // blocks, adding llvm.coro instrinsics and setting up control flow.
486   CoroMachinery coro = setupCoroMachinery(func);
487 
488   // Suspend async function at the end of an entry block, and resume it using
489   // Async execute API (execution will be resumed in a thread managed by the
490   // async runtime).
491   Block *entryBlock = &func.getBlocks().front();
492   OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
493 
494   // A pointer to coroutine resume intrinsic wrapper.
495   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
496   auto resumePtr = builder.create<LLVM::AddressOfOp>(
497       loc, resumeFnTy.getPointerTo(), kResume);
498 
499   // Save the coroutine state: @llvm.coro.save
500   auto coroSave = builder.create<LLVM::CallOp>(
501       loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
502       ValueRange({coro.coroHandle}));
503 
504   // Call async runtime API to execute a coroutine in the managed thread.
505   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
506   builder.create<CallOp>(loc, Type(), kExecute, executeArgs);
507 
508   // Split the entry block before the terminator.
509   Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
510                                      entryBlock->getTerminator());
511 
512   // Await on all dependencies before starting to execute the body region.
513   builder.setInsertionPointToStart(resume);
514   for (size_t i = 0; i < execute.dependencies().size(); ++i)
515     builder.create<AwaitOp>(loc, func.getArgument(i));
516 
517   // Map from function inputs defined above the execute op to the function
518   // arguments.
519   BlockAndValueMapping valueMapping;
520   valueMapping.map(functionInputs, func.getArguments());
521 
522   // Clone all operations from the execute operation body into the outlined
523   // function body, and replace all `async.yield` operations with a call
524   // to async runtime to emplace the result token.
525   for (Operation &op : execute.body().getOps()) {
526     if (isa<async::YieldOp>(op)) {
527       builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
528       continue;
529     }
530     builder.clone(op, valueMapping);
531   }
532 
533   // Replace the original `async.execute` with a call to outlined function.
534   OpBuilder callBuilder(execute);
535   auto callOutlinedFunc =
536       callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
537                                  functionInputs.getArrayRef());
538   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
539   execute.erase();
540 
541   return {func, coro};
542 }
543 
544 //===----------------------------------------------------------------------===//
545 // Convert Async dialect types to LLVM types.
546 //===----------------------------------------------------------------------===//
547 
548 namespace {
549 class AsyncRuntimeTypeConverter : public TypeConverter {
550 public:
551   AsyncRuntimeTypeConverter() { addConversion(convertType); }
552 
553   static Type convertType(Type type) {
554     MLIRContext *ctx = type.getContext();
555     // Convert async tokens and groups to opaque pointers.
556     if (type.isa<TokenType, GroupType>())
557       return LLVM::LLVMType::getInt8PtrTy(ctx);
558     return type;
559   }
560 };
561 } // namespace
562 
563 //===----------------------------------------------------------------------===//
564 // Convert types for all call operations to lowered async types.
565 //===----------------------------------------------------------------------===//
566 
567 namespace {
568 class CallOpOpConversion : public ConversionPattern {
569 public:
570   explicit CallOpOpConversion(MLIRContext *ctx)
571       : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
572 
573   LogicalResult
574   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
575                   ConversionPatternRewriter &rewriter) const override {
576     AsyncRuntimeTypeConverter converter;
577 
578     SmallVector<Type, 5> resultTypes;
579     converter.convertTypes(op->getResultTypes(), resultTypes);
580 
581     CallOp call = cast<CallOp>(op);
582     rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
583                                         call.getOperands());
584 
585     return success();
586   }
587 };
588 } // namespace
589 
590 //===----------------------------------------------------------------------===//
591 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
592 //===----------------------------------------------------------------------===//
593 
594 namespace {
595 class CreateGroupOpLowering : public ConversionPattern {
596 public:
597   explicit CreateGroupOpLowering(MLIRContext *ctx)
598       : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
599 
600   LogicalResult
601   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
602                   ConversionPatternRewriter &rewriter) const override {
603     auto retTy = GroupType::get(op->getContext());
604     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy);
605     return success();
606   }
607 };
608 } // namespace
609 
610 //===----------------------------------------------------------------------===//
611 // async.add_to_group op lowering to runtime function call.
612 //===----------------------------------------------------------------------===//
613 
614 namespace {
615 class AddToGroupOpLowering : public ConversionPattern {
616 public:
617   explicit AddToGroupOpLowering(MLIRContext *ctx)
618       : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
619 
620   LogicalResult
621   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
622                   ConversionPatternRewriter &rewriter) const override {
623     // Currently we can only add tokens to the group.
624     auto addToGroup = cast<AddToGroupOp>(op);
625     if (!addToGroup.operand().getType().isa<TokenType>())
626       return failure();
627 
628     auto i64 = IntegerType::get(64, op->getContext());
629     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
630     return success();
631   }
632 };
633 } // namespace
634 
635 //===----------------------------------------------------------------------===//
636 // async.await and async.await_all op lowerings to the corresponding async
637 // runtime function calls.
638 //===----------------------------------------------------------------------===//
639 
640 namespace {
641 
642 template <typename AwaitType, typename AwaitableType>
643 class AwaitOpLoweringBase : public ConversionPattern {
644 protected:
645   explicit AwaitOpLoweringBase(
646       MLIRContext *ctx,
647       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
648       StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
649       : ConversionPattern(AwaitType::getOperationName(), 1, ctx),
650         outlinedFunctions(outlinedFunctions),
651         blockingAwaitFuncName(blockingAwaitFuncName),
652         coroAwaitFuncName(coroAwaitFuncName) {}
653 
654 public:
655   LogicalResult
656   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
657                   ConversionPatternRewriter &rewriter) const override {
658     // We can only await on one the `AwaitableType` (for `await` it can be
659     // only a `token`, for `await_all` it is a `group`).
660     auto await = cast<AwaitType>(op);
661     if (!await.operand().getType().template isa<AwaitableType>())
662       return failure();
663 
664     // Check if await operation is inside the outlined coroutine function.
665     auto func = await.template getParentOfType<FuncOp>();
666     auto outlined = outlinedFunctions.find(func);
667     const bool isInCoroutine = outlined != outlinedFunctions.end();
668 
669     Location loc = op->getLoc();
670 
671     // Inside regular function we convert await operation to the blocking
672     // async API await function call.
673     if (!isInCoroutine)
674       rewriter.create<CallOp>(loc, Type(), blockingAwaitFuncName,
675                               ValueRange(op->getOperand(0)));
676 
677     // Inside the coroutine we convert await operation into coroutine suspension
678     // point, and resume execution asynchronously.
679     if (isInCoroutine) {
680       const CoroMachinery &coro = outlined->getSecond();
681 
682       OpBuilder builder(op);
683       MLIRContext *ctx = op->getContext();
684 
685       // A pointer to coroutine resume intrinsic wrapper.
686       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
687       auto resumePtr = builder.create<LLVM::AddressOfOp>(
688           loc, resumeFnTy.getPointerTo(), kResume);
689 
690       // Save the coroutine state: @llvm.coro.save
691       auto coroSave = builder.create<LLVM::CallOp>(
692           loc, LLVM::LLVMTokenType::get(ctx),
693           builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
694 
695       // Call async runtime API to resume a coroutine in the managed thread when
696       // the async await argument becomes ready.
697       SmallVector<Value, 3> awaitAndExecuteArgs = {
698           await.getOperand(), coro.coroHandle, resumePtr.res()};
699       builder.create<CallOp>(loc, Type(), coroAwaitFuncName,
700                              awaitAndExecuteArgs);
701 
702       // Split the entry block before the await operation.
703       addSuspensionPoint(coro, coroSave.getResult(0), op);
704     }
705 
706     // Original operation was replaced by function call or suspension point.
707     rewriter.eraseOp(op);
708 
709     return success();
710   }
711 
712 private:
713   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
714   StringRef blockingAwaitFuncName;
715   StringRef coroAwaitFuncName;
716 };
717 
718 // Lowering for `async.await` operation (only token operands are supported).
719 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
720   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
721 
722 public:
723   explicit AwaitOpLowering(
724       MLIRContext *ctx,
725       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
726       : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
727 };
728 
729 // Lowering for `async.await_all` operation.
730 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
731   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
732 
733 public:
734   explicit AwaitAllOpLowering(
735       MLIRContext *ctx,
736       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
737       : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
738 };
739 
740 } // namespace
741 
742 //===----------------------------------------------------------------------===//
743 
744 namespace {
745 struct ConvertAsyncToLLVMPass
746     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
747   void runOnOperation() override;
748 };
749 
750 void ConvertAsyncToLLVMPass::runOnOperation() {
751   ModuleOp module = getOperation();
752   SymbolTable symbolTable(module);
753 
754   // Outline all `async.execute` body regions into async functions (coroutines).
755   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
756 
757   WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
758     // We currently do not support execute operations that have async value
759     // operands or produce async results.
760     if (!execute.operands().empty() || !execute.results().empty()) {
761       execute.emitOpError("can't outline async.execute op with async value "
762                           "operands or returned async results");
763       return WalkResult::interrupt();
764     }
765 
766     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
767 
768     return WalkResult::advance();
769   });
770 
771   // Failed to outline all async execute operations.
772   if (outlineResult.wasInterrupted()) {
773     signalPassFailure();
774     return;
775   }
776 
777   LLVM_DEBUG({
778     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
779                  << " async functions\n";
780   });
781 
782   // Add declarations for all functions required by the coroutines lowering.
783   addResumeFunction(module);
784   addAsyncRuntimeApiDeclarations(module);
785   addCoroutineIntrinsicsDeclarations(module);
786   addCRuntimeDeclarations(module);
787 
788   MLIRContext *ctx = &getContext();
789 
790   // Convert async dialect types and operations to LLVM dialect.
791   AsyncRuntimeTypeConverter converter;
792   OwningRewritePatternList patterns;
793 
794   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
795   patterns.insert<CallOpOpConversion>(ctx);
796   patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
797   patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
798 
799   ConversionTarget target(*ctx);
800   target.addLegalDialect<LLVM::LLVMDialect>();
801   target.addIllegalDialect<AsyncDialect>();
802   target.addDynamicallyLegalOp<FuncOp>(
803       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
804   target.addDynamicallyLegalOp<CallOp>(
805       [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
806 
807   if (failed(applyPartialConversion(module, target, std::move(patterns))))
808     signalPassFailure();
809 }
810 } // namespace
811 
812 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
813   return std::make_unique<ConvertAsyncToLLVMPass>();
814 }
815