1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
19 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 #define DEBUG_TYPE "convert-async-to-llvm"
28 
29 using namespace mlir;
30 using namespace mlir::async;
31 
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35 
// Reference counting of async objects.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Switching tokens and values to the available or error state.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Error state queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Blocking waits.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Value storage access and group membership.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Scheduling coroutine resumption on the runtime.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// Runtime queries.
// NOTE(review): "Runtim" (missing 'e') differs from every other symbol above.
// The string must match the symbol exported by the async runtime exactly, so
// confirm against the runtime library before "fixing" the spelling here.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
64 
65 namespace {
66 /// Async Runtime API function types.
67 ///
68 /// Because we can't create API function signature for type parametrized
69 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
70 /// lowering all async data types become opaque pointers at runtime.
71 struct AsyncAPI {
72   // All async types are lowered to opaque i8* LLVM pointers at runtime.
73   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
74     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
75   }
76 
77   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
78     return LLVM::LLVMTokenType::get(ctx);
79   }
80 
81   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
82     auto ref = opaquePointerType(ctx);
83     auto count = IntegerType::get(ctx, 64);
84     return FunctionType::get(ctx, {ref, count}, {});
85   }
86 
87   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
88     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
89   }
90 
91   static FunctionType createValueFunctionType(MLIRContext *ctx) {
92     auto i64 = IntegerType::get(ctx, 64);
93     auto value = opaquePointerType(ctx);
94     return FunctionType::get(ctx, {i64}, {value});
95   }
96 
97   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
98     auto i64 = IntegerType::get(ctx, 64);
99     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
100   }
101 
102   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
103     auto value = opaquePointerType(ctx);
104     auto storage = opaquePointerType(ctx);
105     return FunctionType::get(ctx, {value}, {storage});
106   }
107 
108   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
109     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
110   }
111 
112   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
113     auto value = opaquePointerType(ctx);
114     return FunctionType::get(ctx, {value}, {});
115   }
116 
117   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
118     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
119   }
120 
121   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
122     auto value = opaquePointerType(ctx);
123     return FunctionType::get(ctx, {value}, {});
124   }
125 
126   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
127     auto i1 = IntegerType::get(ctx, 1);
128     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
129   }
130 
131   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
132     auto value = opaquePointerType(ctx);
133     auto i1 = IntegerType::get(ctx, 1);
134     return FunctionType::get(ctx, {value}, {i1});
135   }
136 
137   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
138     auto i1 = IntegerType::get(ctx, 1);
139     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
140   }
141 
142   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
143     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
144   }
145 
146   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
147     auto value = opaquePointerType(ctx);
148     return FunctionType::get(ctx, {value}, {});
149   }
150 
151   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
152     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
153   }
154 
155   static FunctionType executeFunctionType(MLIRContext *ctx) {
156     auto hdl = opaquePointerType(ctx);
157     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
158     return FunctionType::get(ctx, {hdl, resume}, {});
159   }
160 
161   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
162     auto i64 = IntegerType::get(ctx, 64);
163     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
164                              {i64});
165   }
166 
167   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
168     auto hdl = opaquePointerType(ctx);
169     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
170     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
171   }
172 
173   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
174     auto value = opaquePointerType(ctx);
175     auto hdl = opaquePointerType(ctx);
176     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
177     return FunctionType::get(ctx, {value, hdl, resume}, {});
178   }
179 
180   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
181     auto hdl = opaquePointerType(ctx);
182     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
183     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
184   }
185 
186   static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187     return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188   }
189 
190   // Auxiliary coroutine resume intrinsic wrapper.
191   static Type resumeFunctionType(MLIRContext *ctx) {
192     auto voidTy = LLVM::LLVMVoidType::get(ctx);
193     auto i8Ptr = opaquePointerType(ctx);
194     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
195   }
196 };
197 } // namespace
198 
199 /// Adds Async Runtime C API declarations to the module.
200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201   auto builder =
202       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
203 
204   auto addFuncDecl = [&](StringRef name, FunctionType type) {
205     if (module.lookupSymbol(name))
206       return;
207     builder.create<func::FuncOp>(name, type).setPrivate();
208   };
209 
210   MLIRContext *ctx = module.getContext();
211   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
218   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
220   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229   addFuncDecl(kAwaitTokenAndExecute,
230               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231   addFuncDecl(kAwaitValueAndExecute,
232               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233   addFuncDecl(kAwaitAllAndExecute,
234               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235   addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // Coroutine resume function wrapper.
240 //===----------------------------------------------------------------------===//
241 
242 static constexpr const char *kResume = "__resume";
243 
244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245 /// intrinsics. We need this function to be able to pass it to the async
246 /// runtime execute API.
247 static void addResumeFunction(ModuleOp module) {
248   if (module.lookupSymbol(kResume))
249     return;
250 
251   MLIRContext *ctx = module.getContext();
252   auto loc = module.getLoc();
253   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
254 
255   auto voidTy = LLVM::LLVMVoidType::get(ctx);
256   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
257 
258   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
259       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
260   resumeOp.setPrivate();
261 
262   auto *block = resumeOp.addEntryBlock();
263   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
264 
265   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
266   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // Convert Async dialect types to LLVM types.
271 //===----------------------------------------------------------------------===//
272 
273 namespace {
274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275 /// their runtime type (opaque pointers) and does not convert any other types.
276 class AsyncRuntimeTypeConverter : public TypeConverter {
277 public:
278   AsyncRuntimeTypeConverter() {
279     addConversion([](Type type) { return type; });
280     addConversion(convertAsyncTypes);
281   }
282 
283   static Optional<Type> convertAsyncTypes(Type type) {
284     if (type.isa<TokenType, GroupType, ValueType>())
285       return AsyncAPI::opaquePointerType(type.getContext());
286 
287     if (type.isa<CoroIdType, CoroStateType>())
288       return AsyncAPI::tokenType(type.getContext());
289     if (type.isa<CoroHandleType>())
290       return AsyncAPI::opaquePointerType(type.getContext());
291 
292     return llvm::None;
293   }
294 };
295 } // namespace
296 
297 //===----------------------------------------------------------------------===//
298 // Convert async.coro.id to @llvm.coro.id intrinsic.
299 //===----------------------------------------------------------------------===//
300 
301 namespace {
302 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
303 public:
304   using OpConversionPattern::OpConversionPattern;
305 
306   LogicalResult
307   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
308                   ConversionPatternRewriter &rewriter) const override {
309     auto token = AsyncAPI::tokenType(op->getContext());
310     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
311     auto loc = op->getLoc();
312 
313     // Constants for initializing coroutine frame.
314     auto constZero = rewriter.create<LLVM::ConstantOp>(
315         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
316     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
317 
318     // Get coroutine id: @llvm.coro.id.
319     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
320         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
321 
322     return success();
323   }
324 };
325 } // namespace
326 
327 //===----------------------------------------------------------------------===//
328 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
329 //===----------------------------------------------------------------------===//
330 
331 namespace {
332 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
333 public:
334   using OpConversionPattern::OpConversionPattern;
335 
336   LogicalResult
337   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
338                   ConversionPatternRewriter &rewriter) const override {
339     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
340     auto loc = op->getLoc();
341 
342     // Get coroutine frame size: @llvm.coro.size.i64.
343     Value coroSize =
344         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
345     // Get coroutine frame alignment: @llvm.coro.align.i64.
346     Value coroAlign =
347         rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
348 
349     // Round up the size to be multiple of the alignment. Since aligned_alloc
350     // requires the size parameter be an integral multiple of the alignment
351     // parameter.
352     auto makeConstant = [&](uint64_t c) {
353       return rewriter.create<LLVM::ConstantOp>(
354           op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c));
355     };
356     coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
357     coroSize =
358         rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
359     Value negCoroAlign =
360         rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
361     coroSize =
362         rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
363 
364     // Allocate memory for the coroutine frame.
365     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
366         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
367     auto coroAlloc = rewriter.create<LLVM::CallOp>(
368         loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
369         ValueRange{coroAlign, coroSize});
370 
371     // Begin a coroutine: @llvm.coro.begin.
372     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
373     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
374         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
375 
376     return success();
377   }
378 };
379 } // namespace
380 
381 //===----------------------------------------------------------------------===//
382 // Convert async.coro.free to @llvm.coro.free intrinsic.
383 //===----------------------------------------------------------------------===//
384 
385 namespace {
386 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
387 public:
388   using OpConversionPattern::OpConversionPattern;
389 
390   LogicalResult
391   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
392                   ConversionPatternRewriter &rewriter) const override {
393     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
394     auto loc = op->getLoc();
395 
396     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
397     auto coroMem =
398         rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
399 
400     // Free the memory.
401     auto freeFuncOp =
402         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
403     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
404                                               SymbolRefAttr::get(freeFuncOp),
405                                               ValueRange(coroMem.getResult()));
406 
407     return success();
408   }
409 };
410 } // namespace
411 
412 //===----------------------------------------------------------------------===//
413 // Convert async.coro.end to @llvm.coro.end intrinsic.
414 //===----------------------------------------------------------------------===//
415 
416 namespace {
417 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
418 public:
419   using OpConversionPattern::OpConversionPattern;
420 
421   LogicalResult
422   matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
423                   ConversionPatternRewriter &rewriter) const override {
424     // We are not in the block that is part of the unwind sequence.
425     auto constFalse = rewriter.create<LLVM::ConstantOp>(
426         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
427 
428     // Mark the end of a coroutine: @llvm.coro.end.
429     auto coroHdl = adaptor.handle();
430     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
431                                      ValueRange({coroHdl, constFalse}));
432     rewriter.eraseOp(op);
433 
434     return success();
435   }
436 };
437 } // namespace
438 
439 //===----------------------------------------------------------------------===//
440 // Convert async.coro.save to @llvm.coro.save intrinsic.
441 //===----------------------------------------------------------------------===//
442 
443 namespace {
444 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
445 public:
446   using OpConversionPattern::OpConversionPattern;
447 
448   LogicalResult
449   matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
450                   ConversionPatternRewriter &rewriter) const override {
451     // Save the coroutine state: @llvm.coro.save
452     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
453         op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
454 
455     return success();
456   }
457 };
458 } // namespace
459 
460 //===----------------------------------------------------------------------===//
461 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
462 //===----------------------------------------------------------------------===//
463 
464 namespace {
465 
466 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
467 /// branch to the appropriate block based on the return code.
468 ///
469 /// Before:
470 ///
471 ///   ^suspended:
472 ///     "opBefore"(...)
473 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
474 ///   ^resume:
475 ///     "op"(...)
476 ///   ^cleanup: ...
477 ///   ^suspend: ...
478 ///
479 /// After:
480 ///
481 ///   ^suspended:
482 ///     "opBefore"(...)
483 ///     %suspend = llmv.intr.coro.suspend ...
484 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
485 ///   ^resume:
486 ///     "op"(...)
487 ///   ^cleanup: ...
488 ///   ^suspend: ...
489 ///
490 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
491 public:
492   using OpConversionPattern::OpConversionPattern;
493 
494   LogicalResult
495   matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
496                   ConversionPatternRewriter &rewriter) const override {
497     auto i8 = rewriter.getIntegerType(8);
498     auto i32 = rewriter.getI32Type();
499     auto loc = op->getLoc();
500 
501     // This is not a final suspension point.
502     auto constFalse = rewriter.create<LLVM::ConstantOp>(
503         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
504 
505     // Suspend a coroutine: @llvm.coro.suspend
506     auto coroState = adaptor.state();
507     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
508         loc, i8, ValueRange({coroState, constFalse}));
509 
510     // Cast return code to i32.
511 
512     // After a suspension point decide if we should branch into resume, cleanup
513     // or suspend block of the coroutine (see @llvm.coro.suspend return code
514     // documentation).
515     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
516     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
517                                               op.cleanupDest()};
518     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
519         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
520         /*defaultDestination=*/op.suspendDest(),
521         /*defaultOperands=*/ValueRange(),
522         /*caseValues=*/caseValues,
523         /*caseDestinations=*/caseDest,
524         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
525         /*branchWeights=*/ArrayRef<int32_t>());
526 
527     return success();
528   }
529 };
530 } // namespace
531 
532 //===----------------------------------------------------------------------===//
533 // Convert async.runtime.create to the corresponding runtime API call.
534 //
535 // To allocate storage for the async values we use getelementptr trick:
536 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
537 //===----------------------------------------------------------------------===//
538 
539 namespace {
540 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
541 public:
542   using OpConversionPattern::OpConversionPattern;
543 
544   LogicalResult
545   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
546                   ConversionPatternRewriter &rewriter) const override {
547     TypeConverter *converter = getTypeConverter();
548     Type resultType = op->getResultTypes()[0];
549 
550     // Tokens creation maps to a simple function call.
551     if (resultType.isa<TokenType>()) {
552       rewriter.replaceOpWithNewOp<func::CallOp>(
553           op, kCreateToken, converter->convertType(resultType));
554       return success();
555     }
556 
557     // To create a value we need to compute the storage requirement.
558     if (auto value = resultType.dyn_cast<ValueType>()) {
559       // Returns the size requirements for the async value storage.
560       auto sizeOf = [&](ValueType valueType) -> Value {
561         auto loc = op->getLoc();
562         auto i64 = rewriter.getI64Type();
563 
564         auto storedType = converter->convertType(valueType.getValueType());
565         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
566 
567         // %Size = getelementptr %T* null, int 1
568         // %SizeI = ptrtoint %T* %Size to i64
569         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
570         auto one = rewriter.create<LLVM::ConstantOp>(
571             loc, i64, rewriter.getI64IntegerAttr(1));
572         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
573                                                 one.getResult());
574         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
575       };
576 
577       rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
578                                                 sizeOf(value));
579 
580       return success();
581     }
582 
583     return rewriter.notifyMatchFailure(op, "unsupported async type");
584   }
585 };
586 } // namespace
587 
588 //===----------------------------------------------------------------------===//
589 // Convert async.runtime.create_group to the corresponding runtime API call.
590 //===----------------------------------------------------------------------===//
591 
592 namespace {
593 class RuntimeCreateGroupOpLowering
594     : public OpConversionPattern<RuntimeCreateGroupOp> {
595 public:
596   using OpConversionPattern::OpConversionPattern;
597 
598   LogicalResult
599   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
600                   ConversionPatternRewriter &rewriter) const override {
601     TypeConverter *converter = getTypeConverter();
602     Type resultType = op.getResult().getType();
603 
604     rewriter.replaceOpWithNewOp<func::CallOp>(
605         op, kCreateGroup, converter->convertType(resultType),
606         adaptor.getOperands());
607     return success();
608   }
609 };
610 } // namespace
611 
612 //===----------------------------------------------------------------------===//
613 // Convert async.runtime.set_available to the corresponding runtime API call.
614 //===----------------------------------------------------------------------===//
615 
616 namespace {
617 class RuntimeSetAvailableOpLowering
618     : public OpConversionPattern<RuntimeSetAvailableOp> {
619 public:
620   using OpConversionPattern::OpConversionPattern;
621 
622   LogicalResult
623   matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
624                   ConversionPatternRewriter &rewriter) const override {
625     StringRef apiFuncName =
626         TypeSwitch<Type, StringRef>(op.operand().getType())
627             .Case<TokenType>([](Type) { return kEmplaceToken; })
628             .Case<ValueType>([](Type) { return kEmplaceValue; });
629 
630     rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
631                                               adaptor.getOperands());
632 
633     return success();
634   }
635 };
636 } // namespace
637 
638 //===----------------------------------------------------------------------===//
639 // Convert async.runtime.set_error to the corresponding runtime API call.
640 //===----------------------------------------------------------------------===//
641 
642 namespace {
643 class RuntimeSetErrorOpLowering
644     : public OpConversionPattern<RuntimeSetErrorOp> {
645 public:
646   using OpConversionPattern::OpConversionPattern;
647 
648   LogicalResult
649   matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
650                   ConversionPatternRewriter &rewriter) const override {
651     StringRef apiFuncName =
652         TypeSwitch<Type, StringRef>(op.operand().getType())
653             .Case<TokenType>([](Type) { return kSetTokenError; })
654             .Case<ValueType>([](Type) { return kSetValueError; });
655 
656     rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
657                                               adaptor.getOperands());
658 
659     return success();
660   }
661 };
662 } // namespace
663 
664 //===----------------------------------------------------------------------===//
665 // Convert async.runtime.is_error to the corresponding runtime API call.
666 //===----------------------------------------------------------------------===//
667 
668 namespace {
669 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
670 public:
671   using OpConversionPattern::OpConversionPattern;
672 
673   LogicalResult
674   matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
675                   ConversionPatternRewriter &rewriter) const override {
676     StringRef apiFuncName =
677         TypeSwitch<Type, StringRef>(op.operand().getType())
678             .Case<TokenType>([](Type) { return kIsTokenError; })
679             .Case<GroupType>([](Type) { return kIsGroupError; })
680             .Case<ValueType>([](Type) { return kIsValueError; });
681 
682     rewriter.replaceOpWithNewOp<func::CallOp>(
683         op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
684     return success();
685   }
686 };
687 } // namespace
688 
689 //===----------------------------------------------------------------------===//
690 // Convert async.runtime.await to the corresponding runtime API call.
691 //===----------------------------------------------------------------------===//
692 
693 namespace {
694 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
695 public:
696   using OpConversionPattern::OpConversionPattern;
697 
698   LogicalResult
699   matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
700                   ConversionPatternRewriter &rewriter) const override {
701     StringRef apiFuncName =
702         TypeSwitch<Type, StringRef>(op.operand().getType())
703             .Case<TokenType>([](Type) { return kAwaitToken; })
704             .Case<ValueType>([](Type) { return kAwaitValue; })
705             .Case<GroupType>([](Type) { return kAwaitGroup; });
706 
707     rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
708                                   adaptor.getOperands());
709     rewriter.eraseOp(op);
710 
711     return success();
712   }
713 };
714 } // namespace
715 
716 //===----------------------------------------------------------------------===//
717 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
718 //===----------------------------------------------------------------------===//
719 
720 namespace {
721 class RuntimeAwaitAndResumeOpLowering
722     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
723 public:
724   using OpConversionPattern::OpConversionPattern;
725 
726   LogicalResult
727   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
728                   ConversionPatternRewriter &rewriter) const override {
729     StringRef apiFuncName =
730         TypeSwitch<Type, StringRef>(op.operand().getType())
731             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
732             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
733             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
734 
735     Value operand = adaptor.operand();
736     Value handle = adaptor.handle();
737 
738     // A pointer to coroutine resume intrinsic wrapper.
739     addResumeFunction(op->getParentOfType<ModuleOp>());
740     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
741     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
742         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
743 
744     rewriter.create<func::CallOp>(
745         op->getLoc(), apiFuncName, TypeRange(),
746         ValueRange({operand, handle, resumePtr.getRes()}));
747     rewriter.eraseOp(op);
748 
749     return success();
750   }
751 };
752 } // namespace
753 
754 //===----------------------------------------------------------------------===//
755 // Convert async.runtime.resume to the corresponding runtime API call.
756 //===----------------------------------------------------------------------===//
757 
758 namespace {
759 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
760 public:
761   using OpConversionPattern::OpConversionPattern;
762 
763   LogicalResult
764   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
765                   ConversionPatternRewriter &rewriter) const override {
766     // A pointer to coroutine resume intrinsic wrapper.
767     addResumeFunction(op->getParentOfType<ModuleOp>());
768     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
769     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
770         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
771 
772     // Call async runtime API to execute a coroutine in the managed thread.
773     auto coroHdl = adaptor.handle();
774     rewriter.replaceOpWithNewOp<func::CallOp>(
775         op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
776 
777     return success();
778   }
779 };
780 } // namespace
781 
782 //===----------------------------------------------------------------------===//
783 // Convert async.runtime.store to the corresponding runtime API call.
784 //===----------------------------------------------------------------------===//
785 
786 namespace {
787 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
788 public:
789   using OpConversionPattern::OpConversionPattern;
790 
791   LogicalResult
792   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
793                   ConversionPatternRewriter &rewriter) const override {
794     Location loc = op->getLoc();
795 
796     // Get a pointer to the async value storage from the runtime.
797     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
798     auto storage = adaptor.storage();
799     auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
800                                                     TypeRange(i8Ptr), storage);
801 
802     // Cast from i8* to the LLVM pointer type.
803     auto valueType = op.value().getType();
804     auto llvmValueType = getTypeConverter()->convertType(valueType);
805     if (!llvmValueType)
806       return rewriter.notifyMatchFailure(
807           op, "failed to convert stored value type to LLVM type");
808 
809     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
810         loc, LLVM::LLVMPointerType::get(llvmValueType),
811         storagePtr.getResult(0));
812 
813     // Store the yielded value into the async value storage.
814     auto value = adaptor.value();
815     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
816 
817     // Erase the original runtime store operation.
818     rewriter.eraseOp(op);
819 
820     return success();
821   }
822 };
823 } // namespace
824 
825 //===----------------------------------------------------------------------===//
826 // Convert async.runtime.load to the corresponding runtime API call.
827 //===----------------------------------------------------------------------===//
828 
829 namespace {
830 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
831 public:
832   using OpConversionPattern::OpConversionPattern;
833 
834   LogicalResult
835   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
836                   ConversionPatternRewriter &rewriter) const override {
837     Location loc = op->getLoc();
838 
839     // Get a pointer to the async value storage from the runtime.
840     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
841     auto storage = adaptor.storage();
842     auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
843                                                     TypeRange(i8Ptr), storage);
844 
845     // Cast from i8* to the LLVM pointer type.
846     auto valueType = op.result().getType();
847     auto llvmValueType = getTypeConverter()->convertType(valueType);
848     if (!llvmValueType)
849       return rewriter.notifyMatchFailure(
850           op, "failed to convert loaded value type to LLVM type");
851 
852     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
853         loc, LLVM::LLVMPointerType::get(llvmValueType),
854         storagePtr.getResult(0));
855 
856     // Load from the casted pointer.
857     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
858 
859     return success();
860   }
861 };
862 } // namespace
863 
864 //===----------------------------------------------------------------------===//
865 // Convert async.runtime.add_to_group to the corresponding runtime API call.
866 //===----------------------------------------------------------------------===//
867 
868 namespace {
869 class RuntimeAddToGroupOpLowering
870     : public OpConversionPattern<RuntimeAddToGroupOp> {
871 public:
872   using OpConversionPattern::OpConversionPattern;
873 
874   LogicalResult
875   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
876                   ConversionPatternRewriter &rewriter) const override {
877     // Currently we can only add tokens to the group.
878     if (!op.operand().getType().isa<TokenType>())
879       return rewriter.notifyMatchFailure(op, "only token type is supported");
880 
881     // Replace with a runtime API function call.
882     rewriter.replaceOpWithNewOp<func::CallOp>(
883         op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
884 
885     return success();
886   }
887 };
888 } // namespace
889 
890 //===----------------------------------------------------------------------===//
891 // Convert async.runtime.num_worker_threads to the corresponding runtime API
892 // call.
893 //===----------------------------------------------------------------------===//
894 
895 namespace {
896 class RuntimeNumWorkerThreadsOpLowering
897     : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
898 public:
899   using OpConversionPattern::OpConversionPattern;
900 
901   LogicalResult
902   matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
903                   ConversionPatternRewriter &rewriter) const override {
904 
905     // Replace with a runtime API function call.
906     rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
907                                               rewriter.getIndexType());
908 
909     return success();
910   }
911 };
912 } // namespace
913 
914 //===----------------------------------------------------------------------===//
915 // Async reference counting ops lowering (`async.runtime.add_ref` and
916 // `async.runtime.drop_ref` to the corresponding API calls).
917 //===----------------------------------------------------------------------===//
918 
919 namespace {
920 template <typename RefCountingOp>
921 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
922 public:
923   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
924                                  StringRef apiFunctionName)
925       : OpConversionPattern<RefCountingOp>(converter, ctx),
926         apiFunctionName(apiFunctionName) {}
927 
928   LogicalResult
929   matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
930                   ConversionPatternRewriter &rewriter) const override {
931     auto count = rewriter.create<arith::ConstantOp>(
932         op->getLoc(), rewriter.getI64Type(),
933         rewriter.getI64IntegerAttr(op.count()));
934 
935     auto operand = adaptor.operand();
936     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
937                                               ValueRange({operand, count}));
938 
939     return success();
940   }
941 
942 private:
943   StringRef apiFunctionName;
944 };
945 
946 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
947 public:
948   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
949       : RefCountingOpLowering(converter, ctx, kAddRef) {}
950 };
951 
952 class RuntimeDropRefOpLowering
953     : public RefCountingOpLowering<RuntimeDropRefOp> {
954 public:
955   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
956       : RefCountingOpLowering(converter, ctx, kDropRef) {}
957 };
958 } // namespace
959 
960 //===----------------------------------------------------------------------===//
961 // Convert return operations that return async values from async regions.
962 //===----------------------------------------------------------------------===//
963 
964 namespace {
965 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
966 public:
967   using OpConversionPattern::OpConversionPattern;
968 
969   LogicalResult
970   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
971                   ConversionPatternRewriter &rewriter) const override {
972     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
973     return success();
974   }
975 };
976 } // namespace
977 
978 //===----------------------------------------------------------------------===//
979 
980 namespace {
981 struct ConvertAsyncToLLVMPass
982     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
983   void runOnOperation() override;
984 };
985 } // namespace
986 
987 void ConvertAsyncToLLVMPass::runOnOperation() {
988   ModuleOp module = getOperation();
989   MLIRContext *ctx = module->getContext();
990 
991   // Add declarations for most functions required by the coroutines lowering.
992   // We delay adding the resume function until it's needed because it currently
993   // fails to compile unless '-O0' is specified.
994   addAsyncRuntimeApiDeclarations(module);
995 
996   // Lower async.runtime and async.coro operations to Async Runtime API and
997   // LLVM coroutine intrinsics.
998 
999   // Convert async dialect types and operations to LLVM dialect.
1000   AsyncRuntimeTypeConverter converter;
1001   RewritePatternSet patterns(ctx);
1002 
1003   // We use conversion to LLVM type to lower async.runtime load and store
1004   // operations.
1005   LLVMTypeConverter llvmConverter(ctx);
1006   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
1007 
1008   // Convert async types in function signatures and function calls.
1009   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1010                                                                  converter);
1011   populateCallOpTypeConversionPattern(patterns, converter);
1012 
1013   // Convert return operations inside async.execute regions.
1014   patterns.add<ReturnOpOpConversion>(converter, ctx);
1015 
1016   // Lower async.runtime operations to the async runtime API calls.
1017   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1018                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1019                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1020                RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1021                RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1022                                                                   ctx);
1023 
1024   // Lower async.runtime operations that rely on LLVM type converter to convert
1025   // from async value payload type to the LLVM type.
1026   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1027                RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
1028                                                               ctx);
1029 
1030   // Lower async coroutine operations to LLVM coroutine intrinsics.
1031   patterns
1032       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1033            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1034           converter, ctx);
1035 
1036   ConversionTarget target(*ctx);
1037   target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1038                     UnrealizedConversionCastOp>();
1039   target.addLegalDialect<LLVM::LLVMDialect>();
1040 
1041   // All operations from Async dialect must be lowered to the runtime API and
1042   // LLVM intrinsics calls.
1043   target.addIllegalDialect<AsyncDialect>();
1044 
1045   // Add dynamic legality constraints to apply conversions defined above.
1046   target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1047     return converter.isSignatureLegal(op.getFunctionType());
1048   });
1049   target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1050     return converter.isLegal(op.getOperandTypes());
1051   });
1052   target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1053     return converter.isSignatureLegal(op.getCalleeType());
1054   });
1055 
1056   if (failed(applyPartialConversion(module, target, std::move(patterns))))
1057     signalPassFailure();
1058 }
1059 
1060 //===----------------------------------------------------------------------===//
1061 // Patterns for structural type conversions for the Async dialect operations.
1062 //===----------------------------------------------------------------------===//
1063 
1064 namespace {
1065 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1066 public:
1067   using OpConversionPattern::OpConversionPattern;
1068   LogicalResult
1069   matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1070                   ConversionPatternRewriter &rewriter) const override {
1071     ExecuteOp newOp =
1072         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1073     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1074                                 newOp.getRegion().end());
1075 
1076     // Set operands and update block argument and result types.
1077     newOp->setOperands(adaptor.getOperands());
1078     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1079       return failure();
1080     for (auto result : newOp.getResults())
1081       result.setType(typeConverter->convertType(result.getType()));
1082 
1083     rewriter.replaceOp(op, newOp.getResults());
1084     return success();
1085   }
1086 };
1087 
1088 // Dummy pattern to trigger the appropriate type conversion / materialization.
1089 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1090 public:
1091   using OpConversionPattern::OpConversionPattern;
1092   LogicalResult
1093   matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1094                   ConversionPatternRewriter &rewriter) const override {
1095     rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1096     return success();
1097   }
1098 };
1099 
1100 // Dummy pattern to trigger the appropriate type conversion / materialization.
1101 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1102 public:
1103   using OpConversionPattern::OpConversionPattern;
1104   LogicalResult
1105   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1106                   ConversionPatternRewriter &rewriter) const override {
1107     rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1108     return success();
1109   }
1110 };
1111 } // namespace
1112 
1113 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1114   return std::make_unique<ConvertAsyncToLLVMPass>();
1115 }
1116 
1117 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1118     TypeConverter &typeConverter, RewritePatternSet &patterns,
1119     ConversionTarget &target) {
1120   typeConverter.addConversion([&](TokenType type) { return type; });
1121   typeConverter.addConversion([&](ValueType type) {
1122     Type converted = typeConverter.convertType(type.getValueType());
1123     return converted ? ValueType::get(converted) : converted;
1124   });
1125 
1126   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1127       typeConverter, patterns.getContext());
1128 
1129   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1130       [&](Operation *op) { return typeConverter.isLegal(op); });
1131 }
1132