1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 #define DEBUG_TYPE "convert-async-to-llvm"
28 
29 using namespace mlir;
30 using namespace mlir::async;
31 
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35 
// Symbol names of the Async Runtime C API entry points targeted by this
// lowering. addAsyncRuntimeApiDeclarations declares each of them as a private,
// body-less `func`, so every string must match, byte for byte, the name of a
// symbol provided elsewhere (presumably the async runtime library).
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
// NOTE(review): "Runtim" (missing 'e') differs from every other entry point
// above. This string must match the exported runtime symbol exactly, so check
// the runtime's spelling before "fixing" it here alone.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
64 
65 namespace {
66 /// Async Runtime API function types.
67 ///
68 /// Because we can't create API function signature for type parametrized
69 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
70 /// lowering all async data types become opaque pointers at runtime.
71 struct AsyncAPI {
72   // All async types are lowered to opaque i8* LLVM pointers at runtime.
73   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
74     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
75   }
76 
77   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
78     return LLVM::LLVMTokenType::get(ctx);
79   }
80 
81   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
82     auto ref = opaquePointerType(ctx);
83     auto count = IntegerType::get(ctx, 64);
84     return FunctionType::get(ctx, {ref, count}, {});
85   }
86 
87   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
88     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
89   }
90 
91   static FunctionType createValueFunctionType(MLIRContext *ctx) {
92     auto i64 = IntegerType::get(ctx, 64);
93     auto value = opaquePointerType(ctx);
94     return FunctionType::get(ctx, {i64}, {value});
95   }
96 
97   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
98     auto i64 = IntegerType::get(ctx, 64);
99     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
100   }
101 
102   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
103     auto value = opaquePointerType(ctx);
104     auto storage = opaquePointerType(ctx);
105     return FunctionType::get(ctx, {value}, {storage});
106   }
107 
108   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
109     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
110   }
111 
112   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
113     auto value = opaquePointerType(ctx);
114     return FunctionType::get(ctx, {value}, {});
115   }
116 
117   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
118     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
119   }
120 
121   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
122     auto value = opaquePointerType(ctx);
123     return FunctionType::get(ctx, {value}, {});
124   }
125 
126   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
127     auto i1 = IntegerType::get(ctx, 1);
128     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
129   }
130 
131   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
132     auto value = opaquePointerType(ctx);
133     auto i1 = IntegerType::get(ctx, 1);
134     return FunctionType::get(ctx, {value}, {i1});
135   }
136 
137   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
138     auto i1 = IntegerType::get(ctx, 1);
139     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
140   }
141 
142   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
143     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
144   }
145 
146   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
147     auto value = opaquePointerType(ctx);
148     return FunctionType::get(ctx, {value}, {});
149   }
150 
151   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
152     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
153   }
154 
155   static FunctionType executeFunctionType(MLIRContext *ctx) {
156     auto hdl = opaquePointerType(ctx);
157     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
158     return FunctionType::get(ctx, {hdl, resume}, {});
159   }
160 
161   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
162     auto i64 = IntegerType::get(ctx, 64);
163     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
164                              {i64});
165   }
166 
167   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
168     auto hdl = opaquePointerType(ctx);
169     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
170     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
171   }
172 
173   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
174     auto value = opaquePointerType(ctx);
175     auto hdl = opaquePointerType(ctx);
176     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
177     return FunctionType::get(ctx, {value, hdl, resume}, {});
178   }
179 
180   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
181     auto hdl = opaquePointerType(ctx);
182     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
183     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
184   }
185 
186   static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187     return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188   }
189 
190   // Auxiliary coroutine resume intrinsic wrapper.
191   static Type resumeFunctionType(MLIRContext *ctx) {
192     auto voidTy = LLVM::LLVMVoidType::get(ctx);
193     auto i8Ptr = opaquePointerType(ctx);
194     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
195   }
196 };
197 } // namespace
198 
199 /// Adds Async Runtime C API declarations to the module.
200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201   auto builder =
202       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
203 
204   auto addFuncDecl = [&](StringRef name, FunctionType type) {
205     if (module.lookupSymbol(name))
206       return;
207     builder.create<FuncOp>(name, type).setPrivate();
208   };
209 
210   MLIRContext *ctx = module.getContext();
211   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
218   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
220   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229   addFuncDecl(kAwaitTokenAndExecute,
230               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231   addFuncDecl(kAwaitValueAndExecute,
232               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233   addFuncDecl(kAwaitAllAndExecute,
234               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235   addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // Coroutine resume function wrapper.
240 //===----------------------------------------------------------------------===//
241 
242 static constexpr const char *kResume = "__resume";
243 
244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245 /// intrinsics. We need this function to be able to pass it to the async
246 /// runtime execute API.
247 static void addResumeFunction(ModuleOp module) {
248   if (module.lookupSymbol(kResume))
249     return;
250 
251   MLIRContext *ctx = module.getContext();
252   auto loc = module.getLoc();
253   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
254 
255   auto voidTy = LLVM::LLVMVoidType::get(ctx);
256   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
257 
258   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
259       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
260   resumeOp.setPrivate();
261 
262   auto *block = resumeOp.addEntryBlock();
263   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
264 
265   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
266   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // Convert Async dialect types to LLVM types.
271 //===----------------------------------------------------------------------===//
272 
273 namespace {
274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275 /// their runtime type (opaque pointers) and does not convert any other types.
276 class AsyncRuntimeTypeConverter : public TypeConverter {
277 public:
278   AsyncRuntimeTypeConverter() {
279     addConversion([](Type type) { return type; });
280     addConversion(convertAsyncTypes);
281   }
282 
283   static Optional<Type> convertAsyncTypes(Type type) {
284     if (type.isa<TokenType, GroupType, ValueType>())
285       return AsyncAPI::opaquePointerType(type.getContext());
286 
287     if (type.isa<CoroIdType, CoroStateType>())
288       return AsyncAPI::tokenType(type.getContext());
289     if (type.isa<CoroHandleType>())
290       return AsyncAPI::opaquePointerType(type.getContext());
291 
292     return llvm::None;
293   }
294 };
295 } // namespace
296 
297 //===----------------------------------------------------------------------===//
298 // Convert async.coro.id to @llvm.coro.id intrinsic.
299 //===----------------------------------------------------------------------===//
300 
301 namespace {
302 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
303 public:
304   using OpConversionPattern::OpConversionPattern;
305 
306   LogicalResult
307   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
308                   ConversionPatternRewriter &rewriter) const override {
309     auto token = AsyncAPI::tokenType(op->getContext());
310     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
311     auto loc = op->getLoc();
312 
313     // Constants for initializing coroutine frame.
314     auto constZero = rewriter.create<LLVM::ConstantOp>(
315         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
316     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
317 
318     // Get coroutine id: @llvm.coro.id.
319     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
320         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
321 
322     return success();
323   }
324 };
325 } // namespace
326 
327 //===----------------------------------------------------------------------===//
328 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
329 //===----------------------------------------------------------------------===//
330 
331 namespace {
332 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
333 public:
334   using OpConversionPattern::OpConversionPattern;
335 
336   LogicalResult
337   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
338                   ConversionPatternRewriter &rewriter) const override {
339     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
340     auto loc = op->getLoc();
341 
342     // Get coroutine frame size: @llvm.coro.size.i64.
343     Value coroSize =
344         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
345     // Get coroutine frame alignment: @llvm.coro.align.i64.
346     Value coroAlign =
347         rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
348 
349     // Round up the size to be multiple of the alignment. Since aligned_alloc
350     // requires the size parameter be an integral multiple of the alignment
351     // parameter.
352     auto makeConstant = [&](uint64_t c) {
353       return rewriter.create<LLVM::ConstantOp>(
354           op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c));
355     };
356     coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
357     coroSize =
358         rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
359     Value negCoroAlign =
360         rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
361     coroSize =
362         rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
363 
364     // Allocate memory for the coroutine frame.
365     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
366         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
367     auto coroAlloc = rewriter.create<LLVM::CallOp>(
368         loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
369         ValueRange{coroAlign, coroSize});
370 
371     // Begin a coroutine: @llvm.coro.begin.
372     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
373     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
374         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
375 
376     return success();
377   }
378 };
379 } // namespace
380 
381 //===----------------------------------------------------------------------===//
382 // Convert async.coro.free to @llvm.coro.free intrinsic.
383 //===----------------------------------------------------------------------===//
384 
385 namespace {
386 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
387 public:
388   using OpConversionPattern::OpConversionPattern;
389 
390   LogicalResult
391   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
392                   ConversionPatternRewriter &rewriter) const override {
393     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
394     auto loc = op->getLoc();
395 
396     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
397     auto coroMem =
398         rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
399 
400     // Free the memory.
401     auto freeFuncOp =
402         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
403     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
404                                               SymbolRefAttr::get(freeFuncOp),
405                                               ValueRange(coroMem.getResult()));
406 
407     return success();
408   }
409 };
410 } // namespace
411 
412 //===----------------------------------------------------------------------===//
413 // Convert async.coro.end to @llvm.coro.end intrinsic.
414 //===----------------------------------------------------------------------===//
415 
416 namespace {
417 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
418 public:
419   using OpConversionPattern::OpConversionPattern;
420 
421   LogicalResult
422   matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
423                   ConversionPatternRewriter &rewriter) const override {
424     // We are not in the block that is part of the unwind sequence.
425     auto constFalse = rewriter.create<LLVM::ConstantOp>(
426         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
427 
428     // Mark the end of a coroutine: @llvm.coro.end.
429     auto coroHdl = adaptor.handle();
430     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
431                                      ValueRange({coroHdl, constFalse}));
432     rewriter.eraseOp(op);
433 
434     return success();
435   }
436 };
437 } // namespace
438 
439 //===----------------------------------------------------------------------===//
440 // Convert async.coro.save to @llvm.coro.save intrinsic.
441 //===----------------------------------------------------------------------===//
442 
443 namespace {
444 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
445 public:
446   using OpConversionPattern::OpConversionPattern;
447 
448   LogicalResult
449   matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
450                   ConversionPatternRewriter &rewriter) const override {
451     // Save the coroutine state: @llvm.coro.save
452     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
453         op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
454 
455     return success();
456   }
457 };
458 } // namespace
459 
460 //===----------------------------------------------------------------------===//
461 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
462 //===----------------------------------------------------------------------===//
463 
464 namespace {
465 
466 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
467 /// branch to the appropriate block based on the return code.
468 ///
469 /// Before:
470 ///
471 ///   ^suspended:
472 ///     "opBefore"(...)
473 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
474 ///   ^resume:
475 ///     "op"(...)
476 ///   ^cleanup: ...
477 ///   ^suspend: ...
478 ///
479 /// After:
480 ///
481 ///   ^suspended:
482 ///     "opBefore"(...)
483 ///     %suspend = llmv.intr.coro.suspend ...
484 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
485 ///   ^resume:
486 ///     "op"(...)
487 ///   ^cleanup: ...
488 ///   ^suspend: ...
489 ///
490 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
491 public:
492   using OpConversionPattern::OpConversionPattern;
493 
494   LogicalResult
495   matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
496                   ConversionPatternRewriter &rewriter) const override {
497     auto i8 = rewriter.getIntegerType(8);
498     auto i32 = rewriter.getI32Type();
499     auto loc = op->getLoc();
500 
501     // This is not a final suspension point.
502     auto constFalse = rewriter.create<LLVM::ConstantOp>(
503         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
504 
505     // Suspend a coroutine: @llvm.coro.suspend
506     auto coroState = adaptor.state();
507     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
508         loc, i8, ValueRange({coroState, constFalse}));
509 
510     // Cast return code to i32.
511 
512     // After a suspension point decide if we should branch into resume, cleanup
513     // or suspend block of the coroutine (see @llvm.coro.suspend return code
514     // documentation).
515     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
516     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
517                                               op.cleanupDest()};
518     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
519         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
520         /*defaultDestination=*/op.suspendDest(),
521         /*defaultOperands=*/ValueRange(),
522         /*caseValues=*/caseValues,
523         /*caseDestinations=*/caseDest,
524         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
525         /*branchWeights=*/ArrayRef<int32_t>());
526 
527     return success();
528   }
529 };
530 } // namespace
531 
532 //===----------------------------------------------------------------------===//
533 // Convert async.runtime.create to the corresponding runtime API call.
534 //
// To compute the storage size for async values we use the getelementptr trick:
536 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
537 //===----------------------------------------------------------------------===//
538 
539 namespace {
540 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
541 public:
542   using OpConversionPattern::OpConversionPattern;
543 
544   LogicalResult
545   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
546                   ConversionPatternRewriter &rewriter) const override {
547     TypeConverter *converter = getTypeConverter();
548     Type resultType = op->getResultTypes()[0];
549 
550     // Tokens creation maps to a simple function call.
551     if (resultType.isa<TokenType>()) {
552       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
553                                           converter->convertType(resultType));
554       return success();
555     }
556 
557     // To create a value we need to compute the storage requirement.
558     if (auto value = resultType.dyn_cast<ValueType>()) {
559       // Returns the size requirements for the async value storage.
560       auto sizeOf = [&](ValueType valueType) -> Value {
561         auto loc = op->getLoc();
562         auto i64 = rewriter.getI64Type();
563 
564         auto storedType = converter->convertType(valueType.getValueType());
565         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
566 
567         // %Size = getelementptr %T* null, int 1
568         // %SizeI = ptrtoint %T* %Size to i64
569         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
570         auto one = rewriter.create<LLVM::ConstantOp>(
571             loc, i64, rewriter.getI64IntegerAttr(1));
572         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
573                                                 one.getResult());
574         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
575       };
576 
577       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
578                                           sizeOf(value));
579 
580       return success();
581     }
582 
583     return rewriter.notifyMatchFailure(op, "unsupported async type");
584   }
585 };
586 } // namespace
587 
588 //===----------------------------------------------------------------------===//
589 // Convert async.runtime.create_group to the corresponding runtime API call.
590 //===----------------------------------------------------------------------===//
591 
592 namespace {
593 class RuntimeCreateGroupOpLowering
594     : public OpConversionPattern<RuntimeCreateGroupOp> {
595 public:
596   using OpConversionPattern::OpConversionPattern;
597 
598   LogicalResult
599   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
600                   ConversionPatternRewriter &rewriter) const override {
601     TypeConverter *converter = getTypeConverter();
602     Type resultType = op.getResult().getType();
603 
604     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup,
605                                         converter->convertType(resultType),
606                                         adaptor.getOperands());
607     return success();
608   }
609 };
610 } // namespace
611 
612 //===----------------------------------------------------------------------===//
613 // Convert async.runtime.set_available to the corresponding runtime API call.
614 //===----------------------------------------------------------------------===//
615 
616 namespace {
617 class RuntimeSetAvailableOpLowering
618     : public OpConversionPattern<RuntimeSetAvailableOp> {
619 public:
620   using OpConversionPattern::OpConversionPattern;
621 
622   LogicalResult
623   matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
624                   ConversionPatternRewriter &rewriter) const override {
625     StringRef apiFuncName =
626         TypeSwitch<Type, StringRef>(op.operand().getType())
627             .Case<TokenType>([](Type) { return kEmplaceToken; })
628             .Case<ValueType>([](Type) { return kEmplaceValue; });
629 
630     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
631                                         adaptor.getOperands());
632 
633     return success();
634   }
635 };
636 } // namespace
637 
638 //===----------------------------------------------------------------------===//
639 // Convert async.runtime.set_error to the corresponding runtime API call.
640 //===----------------------------------------------------------------------===//
641 
642 namespace {
643 class RuntimeSetErrorOpLowering
644     : public OpConversionPattern<RuntimeSetErrorOp> {
645 public:
646   using OpConversionPattern::OpConversionPattern;
647 
648   LogicalResult
649   matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
650                   ConversionPatternRewriter &rewriter) const override {
651     StringRef apiFuncName =
652         TypeSwitch<Type, StringRef>(op.operand().getType())
653             .Case<TokenType>([](Type) { return kSetTokenError; })
654             .Case<ValueType>([](Type) { return kSetValueError; });
655 
656     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
657                                         adaptor.getOperands());
658 
659     return success();
660   }
661 };
662 } // namespace
663 
664 //===----------------------------------------------------------------------===//
665 // Convert async.runtime.is_error to the corresponding runtime API call.
666 //===----------------------------------------------------------------------===//
667 
668 namespace {
669 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
670 public:
671   using OpConversionPattern::OpConversionPattern;
672 
673   LogicalResult
674   matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
675                   ConversionPatternRewriter &rewriter) const override {
676     StringRef apiFuncName =
677         TypeSwitch<Type, StringRef>(op.operand().getType())
678             .Case<TokenType>([](Type) { return kIsTokenError; })
679             .Case<GroupType>([](Type) { return kIsGroupError; })
680             .Case<ValueType>([](Type) { return kIsValueError; });
681 
682     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
683                                         adaptor.getOperands());
684     return success();
685   }
686 };
687 } // namespace
688 
689 //===----------------------------------------------------------------------===//
690 // Convert async.runtime.await to the corresponding runtime API call.
691 //===----------------------------------------------------------------------===//
692 
693 namespace {
694 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
695 public:
696   using OpConversionPattern::OpConversionPattern;
697 
698   LogicalResult
699   matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
700                   ConversionPatternRewriter &rewriter) const override {
701     StringRef apiFuncName =
702         TypeSwitch<Type, StringRef>(op.operand().getType())
703             .Case<TokenType>([](Type) { return kAwaitToken; })
704             .Case<ValueType>([](Type) { return kAwaitValue; })
705             .Case<GroupType>([](Type) { return kAwaitGroup; });
706 
707     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
708                             adaptor.getOperands());
709     rewriter.eraseOp(op);
710 
711     return success();
712   }
713 };
714 } // namespace
715 
716 //===----------------------------------------------------------------------===//
717 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
718 //===----------------------------------------------------------------------===//
719 
720 namespace {
721 class RuntimeAwaitAndResumeOpLowering
722     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
723 public:
724   using OpConversionPattern::OpConversionPattern;
725 
726   LogicalResult
727   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
728                   ConversionPatternRewriter &rewriter) const override {
729     StringRef apiFuncName =
730         TypeSwitch<Type, StringRef>(op.operand().getType())
731             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
732             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
733             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
734 
735     Value operand = adaptor.operand();
736     Value handle = adaptor.handle();
737 
738     // A pointer to coroutine resume intrinsic wrapper.
739     addResumeFunction(op->getParentOfType<ModuleOp>());
740     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
741     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
742         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
743 
744     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
745                             ValueRange({operand, handle, resumePtr.getRes()}));
746     rewriter.eraseOp(op);
747 
748     return success();
749   }
750 };
751 } // namespace
752 
753 //===----------------------------------------------------------------------===//
754 // Convert async.runtime.resume to the corresponding runtime API call.
755 //===----------------------------------------------------------------------===//
756 
757 namespace {
758 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
759 public:
760   using OpConversionPattern::OpConversionPattern;
761 
762   LogicalResult
763   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
764                   ConversionPatternRewriter &rewriter) const override {
765     // A pointer to coroutine resume intrinsic wrapper.
766     addResumeFunction(op->getParentOfType<ModuleOp>());
767     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
768     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
769         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
770 
771     // Call async runtime API to execute a coroutine in the managed thread.
772     auto coroHdl = adaptor.handle();
773     rewriter.replaceOpWithNewOp<CallOp>(
774         op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
775 
776     return success();
777   }
778 };
779 } // namespace
780 
781 //===----------------------------------------------------------------------===//
782 // Convert async.runtime.store to the corresponding runtime API call.
783 //===----------------------------------------------------------------------===//
784 
785 namespace {
786 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
787 public:
788   using OpConversionPattern::OpConversionPattern;
789 
790   LogicalResult
791   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
792                   ConversionPatternRewriter &rewriter) const override {
793     Location loc = op->getLoc();
794 
795     // Get a pointer to the async value storage from the runtime.
796     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
797     auto storage = adaptor.storage();
798     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
799                                               TypeRange(i8Ptr), storage);
800 
801     // Cast from i8* to the LLVM pointer type.
802     auto valueType = op.value().getType();
803     auto llvmValueType = getTypeConverter()->convertType(valueType);
804     if (!llvmValueType)
805       return rewriter.notifyMatchFailure(
806           op, "failed to convert stored value type to LLVM type");
807 
808     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
809         loc, LLVM::LLVMPointerType::get(llvmValueType),
810         storagePtr.getResult(0));
811 
812     // Store the yielded value into the async value storage.
813     auto value = adaptor.value();
814     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
815 
816     // Erase the original runtime store operation.
817     rewriter.eraseOp(op);
818 
819     return success();
820   }
821 };
822 } // namespace
823 
824 //===----------------------------------------------------------------------===//
825 // Convert async.runtime.load to the corresponding runtime API call.
826 //===----------------------------------------------------------------------===//
827 
828 namespace {
829 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
830 public:
831   using OpConversionPattern::OpConversionPattern;
832 
833   LogicalResult
834   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
835                   ConversionPatternRewriter &rewriter) const override {
836     Location loc = op->getLoc();
837 
838     // Get a pointer to the async value storage from the runtime.
839     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
840     auto storage = adaptor.storage();
841     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
842                                               TypeRange(i8Ptr), storage);
843 
844     // Cast from i8* to the LLVM pointer type.
845     auto valueType = op.result().getType();
846     auto llvmValueType = getTypeConverter()->convertType(valueType);
847     if (!llvmValueType)
848       return rewriter.notifyMatchFailure(
849           op, "failed to convert loaded value type to LLVM type");
850 
851     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
852         loc, LLVM::LLVMPointerType::get(llvmValueType),
853         storagePtr.getResult(0));
854 
855     // Load from the casted pointer.
856     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
857 
858     return success();
859   }
860 };
861 } // namespace
862 
863 //===----------------------------------------------------------------------===//
864 // Convert async.runtime.add_to_group to the corresponding runtime API call.
865 //===----------------------------------------------------------------------===//
866 
867 namespace {
868 class RuntimeAddToGroupOpLowering
869     : public OpConversionPattern<RuntimeAddToGroupOp> {
870 public:
871   using OpConversionPattern::OpConversionPattern;
872 
873   LogicalResult
874   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
875                   ConversionPatternRewriter &rewriter) const override {
876     // Currently we can only add tokens to the group.
877     if (!op.operand().getType().isa<TokenType>())
878       return rewriter.notifyMatchFailure(op, "only token type is supported");
879 
880     // Replace with a runtime API function call.
881     rewriter.replaceOpWithNewOp<CallOp>(
882         op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
883 
884     return success();
885   }
886 };
887 } // namespace
888 
889 //===----------------------------------------------------------------------===//
890 // Convert async.runtime.num_worker_threads to the corresponding runtime API
891 // call.
892 //===----------------------------------------------------------------------===//
893 
894 namespace {
895 class RuntimeNumWorkerThreadsOpLowering
896     : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
897 public:
898   using OpConversionPattern::OpConversionPattern;
899 
900   LogicalResult
901   matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
902                   ConversionPatternRewriter &rewriter) const override {
903 
904     // Replace with a runtime API function call.
905     rewriter.replaceOpWithNewOp<CallOp>(op, kGetNumWorkerThreads,
906                                         rewriter.getIndexType());
907 
908     return success();
909   }
910 };
911 } // namespace
912 
913 //===----------------------------------------------------------------------===//
914 // Async reference counting ops lowering (`async.runtime.add_ref` and
915 // `async.runtime.drop_ref` to the corresponding API calls).
916 //===----------------------------------------------------------------------===//
917 
918 namespace {
919 template <typename RefCountingOp>
920 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
921 public:
922   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
923                                  StringRef apiFunctionName)
924       : OpConversionPattern<RefCountingOp>(converter, ctx),
925         apiFunctionName(apiFunctionName) {}
926 
927   LogicalResult
928   matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
929                   ConversionPatternRewriter &rewriter) const override {
930     auto count = rewriter.create<arith::ConstantOp>(
931         op->getLoc(), rewriter.getI64Type(),
932         rewriter.getI64IntegerAttr(op.count()));
933 
934     auto operand = adaptor.operand();
935     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
936                                         ValueRange({operand, count}));
937 
938     return success();
939   }
940 
941 private:
942   StringRef apiFunctionName;
943 };
944 
945 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
946 public:
947   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
948       : RefCountingOpLowering(converter, ctx, kAddRef) {}
949 };
950 
951 class RuntimeDropRefOpLowering
952     : public RefCountingOpLowering<RuntimeDropRefOp> {
953 public:
954   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
955       : RefCountingOpLowering(converter, ctx, kDropRef) {}
956 };
957 } // namespace
958 
959 //===----------------------------------------------------------------------===//
960 // Convert return operations that return async values from async regions.
961 //===----------------------------------------------------------------------===//
962 
963 namespace {
964 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
965 public:
966   using OpConversionPattern::OpConversionPattern;
967 
968   LogicalResult
969   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
970                   ConversionPatternRewriter &rewriter) const override {
971     rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
972     return success();
973   }
974 };
975 } // namespace
976 
977 //===----------------------------------------------------------------------===//
978 
979 namespace {
980 struct ConvertAsyncToLLVMPass
981     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
982   void runOnOperation() override;
983 };
984 } // namespace
985 
/// Pass entry point. Declares the Async Runtime API functions in the module.
/// It then runs a partial dialect conversion that lowers async.runtime and
/// async.coro operations to runtime API calls and LLVM coroutine intrinsics.
/// The same conversion rewrites async types appearing in function
/// signatures, calls and returns. Signals pass failure if the conversion
/// fails.
void ConvertAsyncToLLVMPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext *ctx = module->getContext();

  // Add declarations for most functions required by the coroutines lowering.
  // We delay adding the resume function until it's needed because it currently
  // fails to compile unless '-O0' is specified.
  addAsyncRuntimeApiDeclarations(module);

  // The patterns registered below lower async.runtime and async.coro
  // operations to Async Runtime API calls and LLVM coroutine intrinsics.

  // Convert async dialect types and operations to LLVM dialect.
  AsyncRuntimeTypeConverter converter;
  RewritePatternSet patterns(ctx);

  // We use conversion to LLVM type to lower async.runtime load and store
  // operations (this converter is also handed to the create / create_group
  // lowerings below). It is extended with the async type conversions.
  LLVMTypeConverter llvmConverter(ctx);
  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);

  // Convert async types in function signatures and function calls.
  populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, converter);
  populateCallOpTypeConversionPattern(patterns, converter);

  // Convert return operations inside async.execute regions.
  patterns.add<ReturnOpOpConversion>(converter, ctx);

  // Lower async.runtime operations to the async runtime API calls.
  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
               RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
               RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
               RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
               RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
                                                                  ctx);

  // Lower async.runtime operations that rely on LLVM type converter to convert
  // from async value payload type to the LLVM type.
  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
               RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
                                                              ctx);

  // Lower async coroutine operations to LLVM coroutine intrinsics.
  patterns
      .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
           CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
          converter, ctx);

  // Constants, unrealized casts and the whole LLVM dialect are always legal
  // results of the conversion.
  ConversionTarget target(*ctx);
  target
      .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>();
  target.addLegalDialect<LLVM::LLVMDialect>();

  // All operations from Async dialect must be lowered to the runtime API and
  // LLVM intrinsics calls.
  target.addIllegalDialect<AsyncDialect>();

  // Add dynamic legality constraints to apply conversions defined above:
  // functions, returns and calls are legal only once `converter` considers
  // their signature / operand types legal.
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addDynamicallyLegalOp<ReturnOp>(
      [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
    return converter.isSignatureLegal(op.getCalleeType());
  });

  // Operations not mentioned in `target` are left untouched (partial
  // conversion); any failure aborts the pass.
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
1055 
1056 //===----------------------------------------------------------------------===//
1057 // Patterns for structural type conversions for the Async dialect operations.
1058 //===----------------------------------------------------------------------===//
1059 
1060 namespace {
1061 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1062 public:
1063   using OpConversionPattern::OpConversionPattern;
1064   LogicalResult
1065   matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1066                   ConversionPatternRewriter &rewriter) const override {
1067     ExecuteOp newOp =
1068         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1069     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1070                                 newOp.getRegion().end());
1071 
1072     // Set operands and update block argument and result types.
1073     newOp->setOperands(adaptor.getOperands());
1074     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1075       return failure();
1076     for (auto result : newOp.getResults())
1077       result.setType(typeConverter->convertType(result.getType()));
1078 
1079     rewriter.replaceOp(op, newOp.getResults());
1080     return success();
1081   }
1082 };
1083 
1084 // Dummy pattern to trigger the appropriate type conversion / materialization.
1085 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1086 public:
1087   using OpConversionPattern::OpConversionPattern;
1088   LogicalResult
1089   matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1090                   ConversionPatternRewriter &rewriter) const override {
1091     rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1092     return success();
1093   }
1094 };
1095 
1096 // Dummy pattern to trigger the appropriate type conversion / materialization.
1097 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1098 public:
1099   using OpConversionPattern::OpConversionPattern;
1100   LogicalResult
1101   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1102                   ConversionPatternRewriter &rewriter) const override {
1103     rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1104     return success();
1105   }
1106 };
1107 } // namespace
1108 
1109 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1110   return std::make_unique<ConvertAsyncToLLVMPass>();
1111 }
1112 
1113 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1114     TypeConverter &typeConverter, RewritePatternSet &patterns,
1115     ConversionTarget &target) {
1116   typeConverter.addConversion([&](TokenType type) { return type; });
1117   typeConverter.addConversion([&](ValueType type) {
1118     Type converted = typeConverter.convertType(type.getValueType());
1119     return converted ? ValueType::get(converted) : converted;
1120   });
1121 
1122   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1123       typeConverter, patterns.getContext());
1124 
1125   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1126       [&](Operation *op) { return typeConverter.isLegal(op); });
1127 }
1128