1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Dialect/Async/IR/Async.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 #define DEBUG_TYPE "convert-async-to-llvm"
25 
26 using namespace mlir;
27 using namespace mlir::async;
28 
// Symbol name given to every function outlined from an `async.execute` op
// region.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
31 
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35 
// Creates a new async token.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
// Emplaces an async token (emitted in place of `async.yield`).
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
// Awaits an async token (emitted for `async.await` outside of coroutines).
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
// Takes a coroutine handle and a pointer to the resume function.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Takes an async token, a coroutine handle and a pointer to the resume
// function.
static constexpr const char *kAwaitAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
42 
43 namespace {
44 // Async Runtime API function types.
45 struct AsyncAPI {
46   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
47     return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
48   }
49 
50   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
51     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
52   }
53 
54   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
55     return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
56   }
57 
58   static FunctionType executeFunctionType(MLIRContext *ctx) {
59     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
60     auto resume = resumeFunctionType(ctx).getPointerTo();
61     return FunctionType::get({hdl, resume}, {}, ctx);
62   }
63 
64   static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
65     auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
66     auto resume = resumeFunctionType(ctx).getPointerTo();
67     return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
68   }
69 
70   // Auxiliary coroutine resume intrinsic wrapper.
71   static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
72     auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
73     auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
74     return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false);
75   }
76 };
77 } // namespace
78 
79 // Adds Async Runtime C API declarations to the module.
80 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
81   auto builder = OpBuilder::atBlockTerminator(module.getBody());
82 
83   MLIRContext *ctx = module.getContext();
84   Location loc = module.getLoc();
85 
86   if (!module.lookupSymbol(kCreateToken))
87     builder.create<FuncOp>(loc, kCreateToken,
88                            AsyncAPI::createTokenFunctionType(ctx));
89 
90   if (!module.lookupSymbol(kEmplaceToken))
91     builder.create<FuncOp>(loc, kEmplaceToken,
92                            AsyncAPI::emplaceTokenFunctionType(ctx));
93 
94   if (!module.lookupSymbol(kAwaitToken))
95     builder.create<FuncOp>(loc, kAwaitToken,
96                            AsyncAPI::awaitTokenFunctionType(ctx));
97 
98   if (!module.lookupSymbol(kExecute))
99     builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx));
100 
101   if (!module.lookupSymbol(kAwaitAndExecute))
102     builder.create<FuncOp>(loc, kAwaitAndExecute,
103                            AsyncAPI::awaitAndExecuteFunctionType(ctx));
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // LLVM coroutines intrinsics declarations.
108 //===----------------------------------------------------------------------===//
109 
// Symbol names of the LLVM coroutine intrinsics referenced by the lowering.
static constexpr const char *kCoroId = "llvm.coro.id";
static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64";
static constexpr const char *kCoroBegin = "llvm.coro.begin";
static constexpr const char *kCoroSave = "llvm.coro.save";
static constexpr const char *kCoroSuspend = "llvm.coro.suspend";
static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";
static constexpr const char *kCoroResume = "llvm.coro.resume";
118 
119 /// Adds coroutine intrinsics declarations to the module.
120 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
121   using namespace mlir::LLVM;
122 
123   MLIRContext *ctx = module.getContext();
124   Location loc = module.getLoc();
125 
126   OpBuilder builder(module.getBody()->getTerminator());
127 
128   auto token = LLVMTokenType::get(ctx);
129   auto voidTy = LLVMType::getVoidTy(ctx);
130 
131   auto i8 = LLVMType::getInt8Ty(ctx);
132   auto i1 = LLVMType::getInt1Ty(ctx);
133   auto i32 = LLVMType::getInt32Ty(ctx);
134   auto i64 = LLVMType::getInt64Ty(ctx);
135   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
136 
137   if (!module.lookupSymbol(kCoroId))
138     builder.create<LLVMFuncOp>(
139         loc, kCoroId,
140         LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false));
141 
142   if (!module.lookupSymbol(kCoroSizeI64))
143     builder.create<LLVMFuncOp>(loc, kCoroSizeI64,
144                                LLVMType::getFunctionTy(i64, false));
145 
146   if (!module.lookupSymbol(kCoroBegin))
147     builder.create<LLVMFuncOp>(
148         loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
149 
150   if (!module.lookupSymbol(kCoroSave))
151     builder.create<LLVMFuncOp>(loc, kCoroSave,
152                                LLVMType::getFunctionTy(token, i8Ptr, false));
153 
154   if (!module.lookupSymbol(kCoroSuspend))
155     builder.create<LLVMFuncOp>(loc, kCoroSuspend,
156                                LLVMType::getFunctionTy(i8, {token, i1}, false));
157 
158   if (!module.lookupSymbol(kCoroEnd))
159     builder.create<LLVMFuncOp>(loc, kCoroEnd,
160                                LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false));
161 
162   if (!module.lookupSymbol(kCoroFree))
163     builder.create<LLVMFuncOp>(
164         loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
165 
166   if (!module.lookupSymbol(kCoroResume))
167     builder.create<LLVMFuncOp>(loc, kCoroResume,
168                                LLVMType::getFunctionTy(voidTy, i8Ptr, false));
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Add malloc/free declarations to the module.
173 //===----------------------------------------------------------------------===//
174 
// Symbol names of the C runtime functions used to allocate and release the
// coroutine frame.
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
177 
178 /// Adds malloc/free declarations to the module.
179 static void addCRuntimeDeclarations(ModuleOp module) {
180   using namespace mlir::LLVM;
181 
182   MLIRContext *ctx = module.getContext();
183   Location loc = module.getLoc();
184 
185   OpBuilder builder(module.getBody()->getTerminator());
186 
187   auto voidTy = LLVMType::getVoidTy(ctx);
188   auto i64 = LLVMType::getInt64Ty(ctx);
189   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
190 
191   if (!module.lookupSymbol(kMalloc))
192     builder.create<LLVM::LLVMFuncOp>(
193         loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false));
194 
195   if (!module.lookupSymbol(kFree))
196     builder.create<LLVM::LLVMFuncOp>(
197         loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false));
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // Coroutine resume function wrapper.
202 //===----------------------------------------------------------------------===//
203 
// Symbol name of the wrapper function around the `llvm.coro.resume` intrinsic.
static constexpr const char *kResume = "__resume";
205 
206 // A function that takes a coroutine handle and calls a `llvm.coro.resume`
207 // intrinsics. We need this function to be able to pass it to the async
208 // runtime execute API.
209 static void addResumeFunction(ModuleOp module) {
210   MLIRContext *ctx = module.getContext();
211 
212   OpBuilder moduleBuilder(module.getBody()->getTerminator());
213   Location loc = module.getLoc();
214 
215   if (module.lookupSymbol(kResume))
216     return;
217 
218   auto voidTy = LLVM::LLVMType::getVoidTy(ctx);
219   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
220 
221   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
222       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
223   SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private);
224 
225   auto *block = resumeOp.addEntryBlock();
226   OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
227 
228   blockBuilder.create<LLVM::CallOp>(loc, Type(),
229                                     blockBuilder.getSymbolRefAttr(kCoroResume),
230                                     resumeOp.getArgument(0));
231 
232   blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange());
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // async.execute op outlining to the coroutine functions.
237 //===----------------------------------------------------------------------===//
238 
239 // Function targeted for coroutine transformation has two additional blocks at
240 // the end: coroutine cleanup and coroutine suspension.
241 //
242 // async.await op lowering additionaly creates a resume block for each
243 // operation to enable non-blocking waiting via coroutine suspension.
244 namespace {
245 struct CoroMachinery {
246   Value asyncToken;
247   Value coroHandle;
248   Block *cleanup;
249   Block *suspend;
250 };
251 } // namespace
252 
253 // Builds an coroutine template compatible with LLVM coroutines lowering.
254 //
255 //  - `entry` block sets up the coroutine.
256 //  - `cleanup` block cleans up the coroutine state.
257 //  - `suspend block after the @llvm.coro.end() defines what value will be
258 //    returned to the initial caller of a coroutine. Everything before the
259 //    @llvm.coro.end() will be executed at every suspension point.
260 //
261 // Coroutine structure (only the important bits):
262 //
263 //   func @async_execute_fn(<function-arguments>) -> !async.token {
264 //     ^entryBlock(<function-arguments>):
265 //       %token = <async token> : !async.token // create async runtime token
266 //       %hdl = llvm.call @llvm.coro.id(...)   // create a coroutine handle
267 //       br ^cleanup
268 //
269 //     ^cleanup:
270 //       llvm.call @llvm.coro.free(...)        // delete coroutine state
271 //       br ^suspend
272 //
273 //     ^suspend:
274 //       llvm.call @llvm.coro.end(...)         // marks the end of a coroutine
275 //       return %token : !async.token
276 //   }
277 //
278 // The actual code for the async.execute operation body region will be inserted
279 // before the entry block terminator.
280 //
281 //
282 static CoroMachinery setupCoroMachinery(FuncOp func) {
283   assert(func.getBody().empty() && "Function must have empty body");
284 
285   MLIRContext *ctx = func.getContext();
286 
287   auto token = LLVM::LLVMTokenType::get(ctx);
288   auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
289   auto i32 = LLVM::LLVMType::getInt32Ty(ctx);
290   auto i64 = LLVM::LLVMType::getInt64Ty(ctx);
291   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
292 
293   Block *entryBlock = func.addEntryBlock();
294   Location loc = func.getBody().getLoc();
295 
296   OpBuilder builder = OpBuilder::atBlockBegin(entryBlock);
297 
298   // ------------------------------------------------------------------------ //
299   // Allocate async tokens/values that we will return from a ramp function.
300   // ------------------------------------------------------------------------ //
301   auto createToken =
302       builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx));
303 
304   // ------------------------------------------------------------------------ //
305   // Initialize coroutine: allocate frame, get coroutine handle.
306   // ------------------------------------------------------------------------ //
307 
308   // Constants for initializing coroutine frame.
309   auto constZero =
310       builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0));
311   auto constFalse =
312       builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
313   auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr);
314 
315   // Get coroutine id: @llvm.coro.id
316   auto coroId = builder.create<LLVM::CallOp>(
317       loc, token, builder.getSymbolRefAttr(kCoroId),
318       ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
319 
320   // Get coroutine frame size: @llvm.coro.size.i64
321   auto coroSize = builder.create<LLVM::CallOp>(
322       loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());
323 
324   // Allocate memory for coroutine frame.
325   auto coroAlloc = builder.create<LLVM::CallOp>(
326       loc, i8Ptr, builder.getSymbolRefAttr(kMalloc),
327       ValueRange(coroSize.getResult(0)));
328 
329   // Begin a coroutine: @llvm.coro.begin
330   auto coroHdl = builder.create<LLVM::CallOp>(
331       loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
332       ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));
333 
334   Block *cleanupBlock = func.addBlock();
335   Block *suspendBlock = func.addBlock();
336 
337   // ------------------------------------------------------------------------ //
338   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
339   // ------------------------------------------------------------------------ //
340   builder.setInsertionPointToStart(cleanupBlock);
341 
342   // Get a pointer to the coroutine frame memory: @llvm.coro.free.
343   auto coroMem = builder.create<LLVM::CallOp>(
344       loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree),
345       ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
346 
347   // Free the memory.
348   builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree),
349                                ValueRange(coroMem.getResult(0)));
350   // Branch into the suspend block.
351   builder.create<BranchOp>(loc, suspendBlock);
352 
353   // ------------------------------------------------------------------------ //
354   // Coroutine suspend block: mark the end of a coroutine and return allocated
355   // async token.
356   // ------------------------------------------------------------------------ //
357   builder.setInsertionPointToStart(suspendBlock);
358 
359   // Mark the end of a coroutine: @llvm.coro.end.
360   builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd),
361                                ValueRange({coroHdl.getResult(0), constFalse}));
362 
363   // Return created `async.token` from the suspend block. This will be the
364   // return value of a coroutine ramp function.
365   builder.create<ReturnOp>(loc, createToken.getResult(0));
366 
367   // Branch from the entry block to the cleanup block to create a valid CFG.
368   builder.setInsertionPointToEnd(entryBlock);
369 
370   builder.create<BranchOp>(loc, cleanupBlock);
371 
372   // `async.await` op lowering will create resume blocks for async
373   // continuations, and will conditionally branch to cleanup or suspend blocks.
374 
375   return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
376           suspendBlock};
377 }
378 
379 // Adds a suspension point before the `op`, and moves `op` and all operations
380 // after it into the resume block. Returns a pointer to the resume block.
381 //
382 // `coroState` must be a value returned from the call to @llvm.coro.save(...)
383 // intrinsic (saved coroutine state).
384 //
385 // Before:
386 //
387 //   ^bb0:
388 //     "opBefore"(...)
389 //     "op"(...)
390 //   ^cleanup: ...
391 //   ^suspend: ...
392 //
393 // After:
394 //
395 //   ^bb0:
396 //     "opBefore"(...)
397 //     %suspend = llmv.call @llvm.coro.suspend(...)
398 //     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
399 //   ^resume:
400 //     "op"(...)
401 //   ^cleanup: ...
402 //   ^suspend: ...
403 //
404 static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
405                                  Operation *op) {
406   MLIRContext *ctx = op->getContext();
407   auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
408   auto i8 = LLVM::LLVMType::getInt8Ty(ctx);
409 
410   Location loc = op->getLoc();
411   Block *splitBlock = op->getBlock();
412 
413   // Split the block before `op`, newly added block is the resume block.
414   Block *resume = splitBlock->splitBlock(op);
415 
416   // Add a coroutine suspension in place of original `op` in the split block.
417   OpBuilder builder = OpBuilder::atBlockEnd(splitBlock);
418 
419   auto constFalse =
420       builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
421 
422   // Suspend a coroutine: @llvm.coro.suspend
423   auto coroSuspend = builder.create<LLVM::CallOp>(
424       loc, i8, builder.getSymbolRefAttr(kCoroSuspend),
425       ValueRange({coroState, constFalse}));
426 
427   // After a suspension point decide if we should branch into resume, cleanup
428   // or suspend block of the coroutine (see @llvm.coro.suspend return code
429   // documentation).
430   auto constZero =
431       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0));
432   auto constNegOne =
433       builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1));
434 
435   Block *resumeOrCleanup = builder.createBlock(resume);
436 
437   // Suspend the coroutine ...?
438   builder.setInsertionPointToEnd(splitBlock);
439   auto isNegOne = builder.create<LLVM::ICmpOp>(
440       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
441   builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
442                                  /*falseDest=*/resumeOrCleanup);
443 
444   // ... or resume or cleanup the coroutine?
445   builder.setInsertionPointToStart(resumeOrCleanup);
446   auto isZero = builder.create<LLVM::ICmpOp>(
447       loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
448   builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
449                                  /*falseDest=*/coro.cleanup);
450 
451   return resume;
452 }
453 
454 // Outline the body region attached to the `async.execute` op into a standalone
455 // function.
456 static std::pair<FuncOp, CoroMachinery>
457 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
458   ModuleOp module = execute.getParentOfType<ModuleOp>();
459 
460   MLIRContext *ctx = module.getContext();
461   Location loc = execute.getLoc();
462 
463   OpBuilder moduleBuilder(module.getBody()->getTerminator());
464 
465   // Collect all outlined function inputs.
466   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
467                                               execute.dependencies().end());
468   getUsedValuesDefinedAbove(execute.body(), functionInputs);
469 
470   // Collect types for the outlined function inputs and outputs.
471   auto typesRange = llvm::map_range(
472       functionInputs, [](Value value) { return value.getType(); });
473   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
474   auto outputTypes = execute.getResultTypes();
475 
476   auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
477   auto funcAttrs = ArrayRef<NamedAttribute>();
478 
479   // TODO: Derive outlined function name from the parent FuncOp (support
480   // multiple nested async.execute operations).
481   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
482   symbolTable.insert(func, moduleBuilder.getInsertionPoint());
483 
484   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
485 
486   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
487   // blocks, adding llvm.coro instrinsics and setting up control flow.
488   CoroMachinery coro = setupCoroMachinery(func);
489 
490   // Suspend async function at the end of an entry block, and resume it using
491   // Async execute API (execution will be resumed in a thread managed by the
492   // async runtime).
493   Block *entryBlock = &func.getBlocks().front();
494   OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock);
495 
496   // A pointer to coroutine resume intrinsic wrapper.
497   auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
498   auto resumePtr = builder.create<LLVM::AddressOfOp>(
499       loc, resumeFnTy.getPointerTo(), kResume);
500 
501   // Save the coroutine state: @llvm.coro.save
502   auto coroSave = builder.create<LLVM::CallOp>(
503       loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave),
504       ValueRange({coro.coroHandle}));
505 
506   // Call async runtime API to execute a coroutine in the managed thread.
507   SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
508   builder.create<CallOp>(loc, Type(), kExecute, executeArgs);
509 
510   // Split the entry block before the terminator.
511   Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
512                                      entryBlock->getTerminator());
513 
514   // Await on all dependencies before starting to execute the body region.
515   builder.setInsertionPointToStart(resume);
516   for (size_t i = 0; i < execute.dependencies().size(); ++i)
517     builder.create<AwaitOp>(loc, func.getArgument(i));
518 
519   // Map from function inputs defined above the execute op to the function
520   // arguments.
521   BlockAndValueMapping valueMapping;
522   valueMapping.map(functionInputs, func.getArguments());
523 
524   // Clone all operations from the execute operation body into the outlined
525   // function body, and replace all `async.yield` operations with a call
526   // to async runtime to emplace the result token.
527   for (Operation &op : execute.body().getOps()) {
528     if (isa<async::YieldOp>(op)) {
529       builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
530       continue;
531     }
532     builder.clone(op, valueMapping);
533   }
534 
535   // Replace the original `async.execute` with a call to outlined function.
536   OpBuilder callBuilder(execute);
537   auto callOutlinedFunc =
538       callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
539                                  functionInputs.getArrayRef());
540   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
541   execute.erase();
542 
543   return {func, coro};
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // Convert Async dialect types to LLVM types.
548 //===----------------------------------------------------------------------===//
549 
550 namespace {
551 class AsyncRuntimeTypeConverter : public TypeConverter {
552 public:
553   AsyncRuntimeTypeConverter() { addConversion(convertType); }
554 
555   static Type convertType(Type type) {
556     MLIRContext *ctx = type.getContext();
557     // Convert async tokens to opaque pointers.
558     if (type.isa<TokenType>())
559       return LLVM::LLVMType::getInt8PtrTy(ctx);
560     return type;
561   }
562 };
563 } // namespace
564 
565 //===----------------------------------------------------------------------===//
566 // Convert types for all call operations to lowered async types.
567 //===----------------------------------------------------------------------===//
568 
569 namespace {
570 class CallOpOpConversion : public ConversionPattern {
571 public:
572   explicit CallOpOpConversion(MLIRContext *ctx)
573       : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
574 
575   LogicalResult
576   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
577                   ConversionPatternRewriter &rewriter) const override {
578     AsyncRuntimeTypeConverter converter;
579 
580     SmallVector<Type, 5> resultTypes;
581     converter.convertTypes(op->getResultTypes(), resultTypes);
582 
583     CallOp call = cast<CallOp>(op);
584     rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
585                                         call.getOperands());
586 
587     return success();
588   }
589 };
590 } // namespace
591 
592 //===----------------------------------------------------------------------===//
593 // async.await op lowering to mlirAsyncRuntimeAwaitToken function call.
594 //===----------------------------------------------------------------------===//
595 
596 namespace {
597 class AwaitOpLowering : public ConversionPattern {
598 public:
599   explicit AwaitOpLowering(
600       MLIRContext *ctx,
601       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
602       : ConversionPattern(AwaitOp::getOperationName(), 1, ctx),
603         outlinedFunctions(outlinedFunctions) {}
604 
605   LogicalResult
606   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
607                   ConversionPatternRewriter &rewriter) const override {
608     // We can only await on the token operand. Async valus are not supported.
609     auto await = cast<AwaitOp>(op);
610     if (!await.operand().getType().isa<TokenType>())
611       return failure();
612 
613     // Check if `async.await` is inside the outlined coroutine function.
614     auto func = await.getParentOfType<FuncOp>();
615     auto outlined = outlinedFunctions.find(func);
616     const bool isInCoroutine = outlined != outlinedFunctions.end();
617 
618     Location loc = op->getLoc();
619 
620     // Inside regular function we convert await operation to the blocking
621     // async API await function call.
622     if (!isInCoroutine)
623       rewriter.create<CallOp>(loc, Type(), kAwaitToken,
624                               ValueRange(op->getOperand(0)));
625 
626     // Inside the coroutine we convert await operation into coroutine suspension
627     // point, and resume execution asynchronously.
628     if (isInCoroutine) {
629       const CoroMachinery &coro = outlined->getSecond();
630 
631       OpBuilder builder(op);
632       MLIRContext *ctx = op->getContext();
633 
634       // A pointer to coroutine resume intrinsic wrapper.
635       auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx);
636       auto resumePtr = builder.create<LLVM::AddressOfOp>(
637           loc, resumeFnTy.getPointerTo(), kResume);
638 
639       // Save the coroutine state: @llvm.coro.save
640       auto coroSave = builder.create<LLVM::CallOp>(
641           loc, LLVM::LLVMTokenType::get(ctx),
642           builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle));
643 
644       // Call async runtime API to resume a coroutine in the managed thread when
645       // the async await argument becomes ready.
646       SmallVector<Value, 3> awaitAndExecuteArgs = {
647           await.getOperand(), coro.coroHandle, resumePtr.res()};
648       builder.create<CallOp>(loc, Type(), kAwaitAndExecute,
649                              awaitAndExecuteArgs);
650 
651       // Split the entry block before the await operation.
652       addSuspensionPoint(coro, coroSave.getResult(0), op);
653     }
654 
655     // Original operation was replaced by function call or suspension point.
656     rewriter.eraseOp(op);
657 
658     return success();
659   }
660 
661 private:
662   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
663 };
664 } // namespace
665 
666 //===----------------------------------------------------------------------===//
667 
668 namespace {
669 struct ConvertAsyncToLLVMPass
670     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
671   void runOnOperation() override;
672 };
673 
674 void ConvertAsyncToLLVMPass::runOnOperation() {
675   ModuleOp module = getOperation();
676   SymbolTable symbolTable(module);
677 
678   // Outline all `async.execute` body regions into async functions (coroutines).
679   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
680 
681   WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
682     // We currently do not support execute operations that have async value
683     // operands or produce async results.
684     if (!execute.operands().empty() || !execute.results().empty()) {
685       execute.emitOpError("can't outline async.execute op with async value "
686                           "operands or returned async results");
687       return WalkResult::interrupt();
688     }
689 
690     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
691 
692     return WalkResult::advance();
693   });
694 
695   // Failed to outline all async execute operations.
696   if (outlineResult.wasInterrupted()) {
697     signalPassFailure();
698     return;
699   }
700 
701   LLVM_DEBUG({
702     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
703                  << " async functions\n";
704   });
705 
706   // Add declarations for all functions required by the coroutines lowering.
707   addResumeFunction(module);
708   addAsyncRuntimeApiDeclarations(module);
709   addCoroutineIntrinsicsDeclarations(module);
710   addCRuntimeDeclarations(module);
711 
712   MLIRContext *ctx = &getContext();
713 
714   // Convert async dialect types and operations to LLVM dialect.
715   AsyncRuntimeTypeConverter converter;
716   OwningRewritePatternList patterns;
717 
718   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
719   patterns.insert<CallOpOpConversion>(ctx);
720   patterns.insert<AwaitOpLowering>(ctx, outlinedFunctions);
721 
722   ConversionTarget target(*ctx);
723   target.addLegalDialect<LLVM::LLVMDialect>();
724   target.addIllegalDialect<AsyncDialect>();
725   target.addDynamicallyLegalOp<FuncOp>(
726       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
727   target.addDynamicallyLegalOp<CallOp>(
728       [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
729 
730   if (failed(applyPartialConversion(module, target, std::move(patterns))))
731     signalPassFailure();
732 }
733 } // namespace
734 
735 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
736   return std::make_unique<ConvertAsyncToLLVMPass>();
737 }
738