1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 #define DEBUG_TYPE "convert-async-to-llvm"
28 
29 using namespace mlir;
30 using namespace mlir::async;
31 
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35 
// Symbol names of the Async Runtime C API functions called by the lowered IR.
// addAsyncRuntimeApiDeclarations() declares each of them in the module, using
// the matching signature from AsyncAPI.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Creation of async tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Marking objects available ("emplace") or errored, and querying error state.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
// Awaiting async objects (used by async.runtime.await).
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Hands a coroutine handle plus the resume wrapper to the runtime
// (used by async.runtime.resume).
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Returns a pointer to the payload storage of an async value.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// Await an object, then run a coroutine through the resume wrapper
// (used by async.runtime.await_and_resume).
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
62 
63 namespace {
64 /// Async Runtime API function types.
65 ///
66 /// Because we can't create API function signature for type parametrized
67 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
68 /// lowering all async data types become opaque pointers at runtime.
69 struct AsyncAPI {
70   // All async types are lowered to opaque i8* LLVM pointers at runtime.
71   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
72     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
73   }
74 
75   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
76     return LLVM::LLVMTokenType::get(ctx);
77   }
78 
79   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
80     auto ref = opaquePointerType(ctx);
81     auto count = IntegerType::get(ctx, 64);
82     return FunctionType::get(ctx, {ref, count}, {});
83   }
84 
85   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
86     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
87   }
88 
89   static FunctionType createValueFunctionType(MLIRContext *ctx) {
90     auto i64 = IntegerType::get(ctx, 64);
91     auto value = opaquePointerType(ctx);
92     return FunctionType::get(ctx, {i64}, {value});
93   }
94 
95   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
96     auto i64 = IntegerType::get(ctx, 64);
97     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
98   }
99 
100   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
101     auto value = opaquePointerType(ctx);
102     auto storage = opaquePointerType(ctx);
103     return FunctionType::get(ctx, {value}, {storage});
104   }
105 
106   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
107     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
108   }
109 
110   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
111     auto value = opaquePointerType(ctx);
112     return FunctionType::get(ctx, {value}, {});
113   }
114 
115   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
116     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
117   }
118 
119   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
120     auto value = opaquePointerType(ctx);
121     return FunctionType::get(ctx, {value}, {});
122   }
123 
124   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
125     auto i1 = IntegerType::get(ctx, 1);
126     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
127   }
128 
129   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
130     auto value = opaquePointerType(ctx);
131     auto i1 = IntegerType::get(ctx, 1);
132     return FunctionType::get(ctx, {value}, {i1});
133   }
134 
135   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
136     auto i1 = IntegerType::get(ctx, 1);
137     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
138   }
139 
140   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
141     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
142   }
143 
144   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
145     auto value = opaquePointerType(ctx);
146     return FunctionType::get(ctx, {value}, {});
147   }
148 
149   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
150     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
151   }
152 
153   static FunctionType executeFunctionType(MLIRContext *ctx) {
154     auto hdl = opaquePointerType(ctx);
155     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
156     return FunctionType::get(ctx, {hdl, resume}, {});
157   }
158 
159   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
160     auto i64 = IntegerType::get(ctx, 64);
161     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
162                              {i64});
163   }
164 
165   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
166     auto hdl = opaquePointerType(ctx);
167     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
168     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
169   }
170 
171   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
172     auto value = opaquePointerType(ctx);
173     auto hdl = opaquePointerType(ctx);
174     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
175     return FunctionType::get(ctx, {value, hdl, resume}, {});
176   }
177 
178   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
179     auto hdl = opaquePointerType(ctx);
180     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
181     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
182   }
183 
184   // Auxiliary coroutine resume intrinsic wrapper.
185   static Type resumeFunctionType(MLIRContext *ctx) {
186     auto voidTy = LLVM::LLVMVoidType::get(ctx);
187     auto i8Ptr = opaquePointerType(ctx);
188     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
189   }
190 };
191 } // namespace
192 
193 /// Adds Async Runtime C API declarations to the module.
194 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
195   auto builder =
196       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
197 
198   auto addFuncDecl = [&](StringRef name, FunctionType type) {
199     if (module.lookupSymbol(name))
200       return;
201     builder.create<FuncOp>(name, type).setPrivate();
202   };
203 
204   MLIRContext *ctx = module.getContext();
205   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
206   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
207   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
208   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
209   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
210   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
211   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
212   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
213   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
214   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
215   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
216   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
217   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
218   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
219   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
220   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
221   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
222   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
223   addFuncDecl(kAwaitTokenAndExecute,
224               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
225   addFuncDecl(kAwaitValueAndExecute,
226               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
227   addFuncDecl(kAwaitAllAndExecute,
228               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // Coroutine resume function wrapper.
233 //===----------------------------------------------------------------------===//
234 
235 static constexpr const char *kResume = "__resume";
236 
237 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
238 /// intrinsics. We need this function to be able to pass it to the async
239 /// runtime execute API.
240 static void addResumeFunction(ModuleOp module) {
241   if (module.lookupSymbol(kResume))
242     return;
243 
244   MLIRContext *ctx = module.getContext();
245   auto loc = module.getLoc();
246   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
247 
248   auto voidTy = LLVM::LLVMVoidType::get(ctx);
249   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
250 
251   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
252       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
253   resumeOp.setPrivate();
254 
255   auto *block = resumeOp.addEntryBlock();
256   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
257 
258   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
259   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // Convert Async dialect types to LLVM types.
264 //===----------------------------------------------------------------------===//
265 
266 namespace {
267 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
268 /// their runtime type (opaque pointers) and does not convert any other types.
269 class AsyncRuntimeTypeConverter : public TypeConverter {
270 public:
271   AsyncRuntimeTypeConverter() {
272     addConversion([](Type type) { return type; });
273     addConversion(convertAsyncTypes);
274   }
275 
276   static Optional<Type> convertAsyncTypes(Type type) {
277     if (type.isa<TokenType, GroupType, ValueType>())
278       return AsyncAPI::opaquePointerType(type.getContext());
279 
280     if (type.isa<CoroIdType, CoroStateType>())
281       return AsyncAPI::tokenType(type.getContext());
282     if (type.isa<CoroHandleType>())
283       return AsyncAPI::opaquePointerType(type.getContext());
284 
285     return llvm::None;
286   }
287 };
288 } // namespace
289 
290 //===----------------------------------------------------------------------===//
291 // Convert async.coro.id to @llvm.coro.id intrinsic.
292 //===----------------------------------------------------------------------===//
293 
294 namespace {
295 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
296 public:
297   using OpConversionPattern::OpConversionPattern;
298 
299   LogicalResult
300   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
301                   ConversionPatternRewriter &rewriter) const override {
302     auto token = AsyncAPI::tokenType(op->getContext());
303     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
304     auto loc = op->getLoc();
305 
306     // Constants for initializing coroutine frame.
307     auto constZero = rewriter.create<LLVM::ConstantOp>(
308         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
309     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
310 
311     // Get coroutine id: @llvm.coro.id.
312     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
313         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
314 
315     return success();
316   }
317 };
318 } // namespace
319 
320 //===----------------------------------------------------------------------===//
321 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
322 //===----------------------------------------------------------------------===//
323 
324 namespace {
325 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
326 public:
327   using OpConversionPattern::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override {
332     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
333     auto loc = op->getLoc();
334 
335     // Get coroutine frame size: @llvm.coro.size.i64.
336     Value coroSize =
337         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
338     // Get coroutine frame alignment: @llvm.coro.align.i64.
339     Value coroAlign =
340         rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
341 
342     // Round up the size to be multiple of the alignment. Since aligned_alloc
343     // requires the size parameter be an integral multiple of the alignment
344     // parameter.
345     auto makeConstant = [&](uint64_t c) {
346       return rewriter.create<LLVM::ConstantOp>(
347           op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c));
348     };
349     coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
350     coroSize =
351         rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
352     Value NegCoroAlign =
353         rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
354     coroSize =
355         rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, NegCoroAlign);
356 
357     // Allocate memory for the coroutine frame.
358     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
359         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
360     auto coroAlloc = rewriter.create<LLVM::CallOp>(
361         loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
362         ValueRange{coroAlign, coroSize});
363 
364     // Begin a coroutine: @llvm.coro.begin.
365     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
366     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
367         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
368 
369     return success();
370   }
371 };
372 } // namespace
373 
374 //===----------------------------------------------------------------------===//
375 // Convert async.coro.free to @llvm.coro.free intrinsic.
376 //===----------------------------------------------------------------------===//
377 
378 namespace {
379 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
380 public:
381   using OpConversionPattern::OpConversionPattern;
382 
383   LogicalResult
384   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
385                   ConversionPatternRewriter &rewriter) const override {
386     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
387     auto loc = op->getLoc();
388 
389     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
390     auto coroMem =
391         rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
392 
393     // Free the memory.
394     auto freeFuncOp =
395         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
396     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
397                                               SymbolRefAttr::get(freeFuncOp),
398                                               ValueRange(coroMem.getResult()));
399 
400     return success();
401   }
402 };
403 } // namespace
404 
405 //===----------------------------------------------------------------------===//
406 // Convert async.coro.end to @llvm.coro.end intrinsic.
407 //===----------------------------------------------------------------------===//
408 
409 namespace {
410 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
411 public:
412   using OpConversionPattern::OpConversionPattern;
413 
414   LogicalResult
415   matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
416                   ConversionPatternRewriter &rewriter) const override {
417     // We are not in the block that is part of the unwind sequence.
418     auto constFalse = rewriter.create<LLVM::ConstantOp>(
419         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
420 
421     // Mark the end of a coroutine: @llvm.coro.end.
422     auto coroHdl = adaptor.handle();
423     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
424                                      ValueRange({coroHdl, constFalse}));
425     rewriter.eraseOp(op);
426 
427     return success();
428   }
429 };
430 } // namespace
431 
432 //===----------------------------------------------------------------------===//
433 // Convert async.coro.save to @llvm.coro.save intrinsic.
434 //===----------------------------------------------------------------------===//
435 
436 namespace {
437 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
438 public:
439   using OpConversionPattern::OpConversionPattern;
440 
441   LogicalResult
442   matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
443                   ConversionPatternRewriter &rewriter) const override {
444     // Save the coroutine state: @llvm.coro.save
445     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
446         op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
447 
448     return success();
449   }
450 };
451 } // namespace
452 
453 //===----------------------------------------------------------------------===//
454 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
455 //===----------------------------------------------------------------------===//
456 
457 namespace {
458 
459 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
460 /// branch to the appropriate block based on the return code.
461 ///
462 /// Before:
463 ///
464 ///   ^suspended:
465 ///     "opBefore"(...)
466 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
467 ///   ^resume:
468 ///     "op"(...)
469 ///   ^cleanup: ...
470 ///   ^suspend: ...
471 ///
472 /// After:
473 ///
474 ///   ^suspended:
475 ///     "opBefore"(...)
476 ///     %suspend = llmv.intr.coro.suspend ...
477 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
478 ///   ^resume:
479 ///     "op"(...)
480 ///   ^cleanup: ...
481 ///   ^suspend: ...
482 ///
483 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
484 public:
485   using OpConversionPattern::OpConversionPattern;
486 
487   LogicalResult
488   matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
489                   ConversionPatternRewriter &rewriter) const override {
490     auto i8 = rewriter.getIntegerType(8);
491     auto i32 = rewriter.getI32Type();
492     auto loc = op->getLoc();
493 
494     // This is not a final suspension point.
495     auto constFalse = rewriter.create<LLVM::ConstantOp>(
496         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
497 
498     // Suspend a coroutine: @llvm.coro.suspend
499     auto coroState = adaptor.state();
500     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
501         loc, i8, ValueRange({coroState, constFalse}));
502 
503     // Cast return code to i32.
504 
505     // After a suspension point decide if we should branch into resume, cleanup
506     // or suspend block of the coroutine (see @llvm.coro.suspend return code
507     // documentation).
508     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
509     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
510                                               op.cleanupDest()};
511     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
512         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
513         /*defaultDestination=*/op.suspendDest(),
514         /*defaultOperands=*/ValueRange(),
515         /*caseValues=*/caseValues,
516         /*caseDestinations=*/caseDest,
517         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
518         /*branchWeights=*/ArrayRef<int32_t>());
519 
520     return success();
521   }
522 };
523 } // namespace
524 
525 //===----------------------------------------------------------------------===//
526 // Convert async.runtime.create to the corresponding runtime API call.
527 //
// To size the storage allocated for async values, we use the getelementptr trick:
529 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
530 //===----------------------------------------------------------------------===//
531 
532 namespace {
533 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
534 public:
535   using OpConversionPattern::OpConversionPattern;
536 
537   LogicalResult
538   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
539                   ConversionPatternRewriter &rewriter) const override {
540     TypeConverter *converter = getTypeConverter();
541     Type resultType = op->getResultTypes()[0];
542 
543     // Tokens creation maps to a simple function call.
544     if (resultType.isa<TokenType>()) {
545       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
546                                           converter->convertType(resultType));
547       return success();
548     }
549 
550     // To create a value we need to compute the storage requirement.
551     if (auto value = resultType.dyn_cast<ValueType>()) {
552       // Returns the size requirements for the async value storage.
553       auto sizeOf = [&](ValueType valueType) -> Value {
554         auto loc = op->getLoc();
555         auto i64 = rewriter.getI64Type();
556 
557         auto storedType = converter->convertType(valueType.getValueType());
558         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
559 
560         // %Size = getelementptr %T* null, int 1
561         // %SizeI = ptrtoint %T* %Size to i64
562         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
563         auto one = rewriter.create<LLVM::ConstantOp>(
564             loc, i64, rewriter.getI64IntegerAttr(1));
565         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
566                                                 one.getResult());
567         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
568       };
569 
570       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
571                                           sizeOf(value));
572 
573       return success();
574     }
575 
576     return rewriter.notifyMatchFailure(op, "unsupported async type");
577   }
578 };
579 } // namespace
580 
581 //===----------------------------------------------------------------------===//
582 // Convert async.runtime.create_group to the corresponding runtime API call.
583 //===----------------------------------------------------------------------===//
584 
585 namespace {
586 class RuntimeCreateGroupOpLowering
587     : public OpConversionPattern<RuntimeCreateGroupOp> {
588 public:
589   using OpConversionPattern::OpConversionPattern;
590 
591   LogicalResult
592   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
593                   ConversionPatternRewriter &rewriter) const override {
594     TypeConverter *converter = getTypeConverter();
595     Type resultType = op.getResult().getType();
596 
597     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup,
598                                         converter->convertType(resultType),
599                                         adaptor.getOperands());
600     return success();
601   }
602 };
603 } // namespace
604 
605 //===----------------------------------------------------------------------===//
606 // Convert async.runtime.set_available to the corresponding runtime API call.
607 //===----------------------------------------------------------------------===//
608 
609 namespace {
610 class RuntimeSetAvailableOpLowering
611     : public OpConversionPattern<RuntimeSetAvailableOp> {
612 public:
613   using OpConversionPattern::OpConversionPattern;
614 
615   LogicalResult
616   matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
617                   ConversionPatternRewriter &rewriter) const override {
618     StringRef apiFuncName =
619         TypeSwitch<Type, StringRef>(op.operand().getType())
620             .Case<TokenType>([](Type) { return kEmplaceToken; })
621             .Case<ValueType>([](Type) { return kEmplaceValue; });
622 
623     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
624                                         adaptor.getOperands());
625 
626     return success();
627   }
628 };
629 } // namespace
630 
631 //===----------------------------------------------------------------------===//
632 // Convert async.runtime.set_error to the corresponding runtime API call.
633 //===----------------------------------------------------------------------===//
634 
635 namespace {
636 class RuntimeSetErrorOpLowering
637     : public OpConversionPattern<RuntimeSetErrorOp> {
638 public:
639   using OpConversionPattern::OpConversionPattern;
640 
641   LogicalResult
642   matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
643                   ConversionPatternRewriter &rewriter) const override {
644     StringRef apiFuncName =
645         TypeSwitch<Type, StringRef>(op.operand().getType())
646             .Case<TokenType>([](Type) { return kSetTokenError; })
647             .Case<ValueType>([](Type) { return kSetValueError; });
648 
649     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
650                                         adaptor.getOperands());
651 
652     return success();
653   }
654 };
655 } // namespace
656 
657 //===----------------------------------------------------------------------===//
658 // Convert async.runtime.is_error to the corresponding runtime API call.
659 //===----------------------------------------------------------------------===//
660 
661 namespace {
662 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
663 public:
664   using OpConversionPattern::OpConversionPattern;
665 
666   LogicalResult
667   matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
668                   ConversionPatternRewriter &rewriter) const override {
669     StringRef apiFuncName =
670         TypeSwitch<Type, StringRef>(op.operand().getType())
671             .Case<TokenType>([](Type) { return kIsTokenError; })
672             .Case<GroupType>([](Type) { return kIsGroupError; })
673             .Case<ValueType>([](Type) { return kIsValueError; });
674 
675     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
676                                         adaptor.getOperands());
677     return success();
678   }
679 };
680 } // namespace
681 
682 //===----------------------------------------------------------------------===//
683 // Convert async.runtime.await to the corresponding runtime API call.
684 //===----------------------------------------------------------------------===//
685 
686 namespace {
687 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
688 public:
689   using OpConversionPattern::OpConversionPattern;
690 
691   LogicalResult
692   matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
693                   ConversionPatternRewriter &rewriter) const override {
694     StringRef apiFuncName =
695         TypeSwitch<Type, StringRef>(op.operand().getType())
696             .Case<TokenType>([](Type) { return kAwaitToken; })
697             .Case<ValueType>([](Type) { return kAwaitValue; })
698             .Case<GroupType>([](Type) { return kAwaitGroup; });
699 
700     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
701                             adaptor.getOperands());
702     rewriter.eraseOp(op);
703 
704     return success();
705   }
706 };
707 } // namespace
708 
709 //===----------------------------------------------------------------------===//
710 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
711 //===----------------------------------------------------------------------===//
712 
713 namespace {
714 class RuntimeAwaitAndResumeOpLowering
715     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
716 public:
717   using OpConversionPattern::OpConversionPattern;
718 
719   LogicalResult
720   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
721                   ConversionPatternRewriter &rewriter) const override {
722     StringRef apiFuncName =
723         TypeSwitch<Type, StringRef>(op.operand().getType())
724             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
725             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
726             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
727 
728     Value operand = adaptor.operand();
729     Value handle = adaptor.handle();
730 
731     // A pointer to coroutine resume intrinsic wrapper.
732     addResumeFunction(op->getParentOfType<ModuleOp>());
733     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
734     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
735         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
736 
737     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
738                             ValueRange({operand, handle, resumePtr.getRes()}));
739     rewriter.eraseOp(op);
740 
741     return success();
742   }
743 };
744 } // namespace
745 
746 //===----------------------------------------------------------------------===//
747 // Convert async.runtime.resume to the corresponding runtime API call.
748 //===----------------------------------------------------------------------===//
749 
750 namespace {
751 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
752 public:
753   using OpConversionPattern::OpConversionPattern;
754 
755   LogicalResult
756   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
757                   ConversionPatternRewriter &rewriter) const override {
758     // A pointer to coroutine resume intrinsic wrapper.
759     addResumeFunction(op->getParentOfType<ModuleOp>());
760     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
761     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
762         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
763 
764     // Call async runtime API to execute a coroutine in the managed thread.
765     auto coroHdl = adaptor.handle();
766     rewriter.replaceOpWithNewOp<CallOp>(
767         op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
768 
769     return success();
770   }
771 };
772 } // namespace
773 
774 //===----------------------------------------------------------------------===//
775 // Convert async.runtime.store to the corresponding runtime API call.
776 //===----------------------------------------------------------------------===//
777 
778 namespace {
779 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
780 public:
781   using OpConversionPattern::OpConversionPattern;
782 
783   LogicalResult
784   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
785                   ConversionPatternRewriter &rewriter) const override {
786     Location loc = op->getLoc();
787 
788     // Get a pointer to the async value storage from the runtime.
789     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
790     auto storage = adaptor.storage();
791     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
792                                               TypeRange(i8Ptr), storage);
793 
794     // Cast from i8* to the LLVM pointer type.
795     auto valueType = op.value().getType();
796     auto llvmValueType = getTypeConverter()->convertType(valueType);
797     if (!llvmValueType)
798       return rewriter.notifyMatchFailure(
799           op, "failed to convert stored value type to LLVM type");
800 
801     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
802         loc, LLVM::LLVMPointerType::get(llvmValueType),
803         storagePtr.getResult(0));
804 
805     // Store the yielded value into the async value storage.
806     auto value = adaptor.value();
807     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
808 
809     // Erase the original runtime store operation.
810     rewriter.eraseOp(op);
811 
812     return success();
813   }
814 };
815 } // namespace
816 
817 //===----------------------------------------------------------------------===//
818 // Convert async.runtime.load to the corresponding runtime API call.
819 //===----------------------------------------------------------------------===//
820 
821 namespace {
822 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
823 public:
824   using OpConversionPattern::OpConversionPattern;
825 
826   LogicalResult
827   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
828                   ConversionPatternRewriter &rewriter) const override {
829     Location loc = op->getLoc();
830 
831     // Get a pointer to the async value storage from the runtime.
832     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
833     auto storage = adaptor.storage();
834     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
835                                               TypeRange(i8Ptr), storage);
836 
837     // Cast from i8* to the LLVM pointer type.
838     auto valueType = op.result().getType();
839     auto llvmValueType = getTypeConverter()->convertType(valueType);
840     if (!llvmValueType)
841       return rewriter.notifyMatchFailure(
842           op, "failed to convert loaded value type to LLVM type");
843 
844     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
845         loc, LLVM::LLVMPointerType::get(llvmValueType),
846         storagePtr.getResult(0));
847 
848     // Load from the casted pointer.
849     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
850 
851     return success();
852   }
853 };
854 } // namespace
855 
856 //===----------------------------------------------------------------------===//
857 // Convert async.runtime.add_to_group to the corresponding runtime API call.
858 //===----------------------------------------------------------------------===//
859 
860 namespace {
861 class RuntimeAddToGroupOpLowering
862     : public OpConversionPattern<RuntimeAddToGroupOp> {
863 public:
864   using OpConversionPattern::OpConversionPattern;
865 
866   LogicalResult
867   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
868                   ConversionPatternRewriter &rewriter) const override {
869     // Currently we can only add tokens to the group.
870     if (!op.operand().getType().isa<TokenType>())
871       return rewriter.notifyMatchFailure(op, "only token type is supported");
872 
873     // Replace with a runtime API function call.
874     rewriter.replaceOpWithNewOp<CallOp>(
875         op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
876 
877     return success();
878   }
879 };
880 } // namespace
881 
882 //===----------------------------------------------------------------------===//
883 // Async reference counting ops lowering (`async.runtime.add_ref` and
884 // `async.runtime.drop_ref` to the corresponding API calls).
885 //===----------------------------------------------------------------------===//
886 
887 namespace {
888 template <typename RefCountingOp>
889 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
890 public:
891   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
892                                  StringRef apiFunctionName)
893       : OpConversionPattern<RefCountingOp>(converter, ctx),
894         apiFunctionName(apiFunctionName) {}
895 
896   LogicalResult
897   matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
898                   ConversionPatternRewriter &rewriter) const override {
899     auto count = rewriter.create<arith::ConstantOp>(
900         op->getLoc(), rewriter.getI64Type(),
901         rewriter.getI64IntegerAttr(op.count()));
902 
903     auto operand = adaptor.operand();
904     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
905                                         ValueRange({operand, count}));
906 
907     return success();
908   }
909 
910 private:
911   StringRef apiFunctionName;
912 };
913 
914 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
915 public:
916   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
917       : RefCountingOpLowering(converter, ctx, kAddRef) {}
918 };
919 
920 class RuntimeDropRefOpLowering
921     : public RefCountingOpLowering<RuntimeDropRefOp> {
922 public:
923   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
924       : RefCountingOpLowering(converter, ctx, kDropRef) {}
925 };
926 } // namespace
927 
928 //===----------------------------------------------------------------------===//
929 // Convert return operations that return async values from async regions.
930 //===----------------------------------------------------------------------===//
931 
932 namespace {
933 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
934 public:
935   using OpConversionPattern::OpConversionPattern;
936 
937   LogicalResult
938   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
939                   ConversionPatternRewriter &rewriter) const override {
940     rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
941     return success();
942   }
943 };
944 } // namespace
945 
946 //===----------------------------------------------------------------------===//
947 
948 namespace {
949 struct ConvertAsyncToLLVMPass
950     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
951   void runOnOperation() override;
952 };
953 } // namespace
954 
955 void ConvertAsyncToLLVMPass::runOnOperation() {
956   ModuleOp module = getOperation();
957   MLIRContext *ctx = module->getContext();
958 
959   // Add declarations for most functions required by the coroutines lowering.
960   // We delay adding the resume function until it's needed because it currently
961   // fails to compile unless '-O0' is specified.
962   addAsyncRuntimeApiDeclarations(module);
963 
964   // Lower async.runtime and async.coro operations to Async Runtime API and
965   // LLVM coroutine intrinsics.
966 
967   // Convert async dialect types and operations to LLVM dialect.
968   AsyncRuntimeTypeConverter converter;
969   RewritePatternSet patterns(ctx);
970 
971   // We use conversion to LLVM type to lower async.runtime load and store
972   // operations.
973   LLVMTypeConverter llvmConverter(ctx);
974   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
975 
976   // Convert async types in function signatures and function calls.
977   populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, converter);
978   populateCallOpTypeConversionPattern(patterns, converter);
979 
980   // Convert return operations inside async.execute regions.
981   patterns.add<ReturnOpOpConversion>(converter, ctx);
982 
983   // Lower async.runtime operations to the async runtime API calls.
984   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
985                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
986                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
987                RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
988                RuntimeDropRefOpLowering>(converter, ctx);
989 
990   // Lower async.runtime operations that rely on LLVM type converter to convert
991   // from async value payload type to the LLVM type.
992   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
993                RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
994                                                               ctx);
995 
996   // Lower async coroutine operations to LLVM coroutine intrinsics.
997   patterns
998       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
999            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1000           converter, ctx);
1001 
1002   ConversionTarget target(*ctx);
1003   target
1004       .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>();
1005   target.addLegalDialect<LLVM::LLVMDialect>();
1006 
1007   // All operations from Async dialect must be lowered to the runtime API and
1008   // LLVM intrinsics calls.
1009   target.addIllegalDialect<AsyncDialect>();
1010 
1011   // Add dynamic legality constraints to apply conversions defined above.
1012   target.addDynamicallyLegalOp<FuncOp>(
1013       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1014   target.addDynamicallyLegalOp<ReturnOp>(
1015       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
1016   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
1017     return converter.isSignatureLegal(op.getCalleeType());
1018   });
1019 
1020   if (failed(applyPartialConversion(module, target, std::move(patterns))))
1021     signalPassFailure();
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // Patterns for structural type conversions for the Async dialect operations.
1026 //===----------------------------------------------------------------------===//
1027 
1028 namespace {
1029 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1030 public:
1031   using OpConversionPattern::OpConversionPattern;
1032   LogicalResult
1033   matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1034                   ConversionPatternRewriter &rewriter) const override {
1035     ExecuteOp newOp =
1036         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1037     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1038                                 newOp.getRegion().end());
1039 
1040     // Set operands and update block argument and result types.
1041     newOp->setOperands(adaptor.getOperands());
1042     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1043       return failure();
1044     for (auto result : newOp.getResults())
1045       result.setType(typeConverter->convertType(result.getType()));
1046 
1047     rewriter.replaceOp(op, newOp.getResults());
1048     return success();
1049   }
1050 };
1051 
1052 // Dummy pattern to trigger the appropriate type conversion / materialization.
1053 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1054 public:
1055   using OpConversionPattern::OpConversionPattern;
1056   LogicalResult
1057   matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1058                   ConversionPatternRewriter &rewriter) const override {
1059     rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1060     return success();
1061   }
1062 };
1063 
1064 // Dummy pattern to trigger the appropriate type conversion / materialization.
1065 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1066 public:
1067   using OpConversionPattern::OpConversionPattern;
1068   LogicalResult
1069   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1070                   ConversionPatternRewriter &rewriter) const override {
1071     rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1072     return success();
1073   }
1074 };
1075 } // namespace
1076 
1077 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1078   return std::make_unique<ConvertAsyncToLLVMPass>();
1079 }
1080 
1081 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1082     TypeConverter &typeConverter, RewritePatternSet &patterns,
1083     ConversionTarget &target) {
1084   typeConverter.addConversion([&](TokenType type) { return type; });
1085   typeConverter.addConversion([&](ValueType type) {
1086     Type converted = typeConverter.convertType(type.getValueType());
1087     return converted ? ValueType::get(converted) : converted;
1088   });
1089 
1090   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1091       typeConverter, patterns.getContext());
1092 
1093   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1094       [&](Operation *op) { return typeConverter.isLegal(op); });
1095 }
1096