1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 #define DEBUG_TYPE "convert-async-to-llvm"
23 
24 using namespace mlir;
25 using namespace mlir::async;
26 
27 //===----------------------------------------------------------------------===//
28 // Async Runtime C API declaration.
29 //===----------------------------------------------------------------------===//
30 
// Symbol names of the Async Runtime C API functions. Their declarations are
// added to the module by `addAsyncRuntimeApiDeclarations`, and the lowering
// patterns emit calls to them by name.

// Reference counting of async objects.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Creation of async tokens, values and groups (async.runtime.create).
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Marking tokens/values as available (async.runtime.set_available).
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
// Awaiting tokens, values and groups (async.runtime.await).
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Resuming a coroutine through the runtime (async.runtime.resume).
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Access to the payload storage of an async value (async.runtime.store and
// async.runtime.load).
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
// Adding a token to a group (async.runtime.add_to_group).
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// Awaiting and then resuming a coroutine (async.runtime.await_and_resume).
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
52 
53 namespace {
54 /// Async Runtime API function types.
55 ///
56 /// Because we can't create API function signature for type parametrized
57 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
58 /// lowering all async data types become opaque pointers at runtime.
59 struct AsyncAPI {
60   // All async types are lowered to opaque i8* LLVM pointers at runtime.
61   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
62     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
63   }
64 
65   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
66     return LLVM::LLVMTokenType::get(ctx);
67   }
68 
69   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
70     auto ref = opaquePointerType(ctx);
71     auto count = IntegerType::get(ctx, 32);
72     return FunctionType::get(ctx, {ref, count}, {});
73   }
74 
75   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
76     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
77   }
78 
79   static FunctionType createValueFunctionType(MLIRContext *ctx) {
80     auto i32 = IntegerType::get(ctx, 32);
81     auto value = opaquePointerType(ctx);
82     return FunctionType::get(ctx, {i32}, {value});
83   }
84 
85   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
86     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
87   }
88 
89   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
90     auto value = opaquePointerType(ctx);
91     auto storage = opaquePointerType(ctx);
92     return FunctionType::get(ctx, {value}, {storage});
93   }
94 
95   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
96     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
97   }
98 
99   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
100     auto value = opaquePointerType(ctx);
101     return FunctionType::get(ctx, {value}, {});
102   }
103 
104   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
105     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
106   }
107 
108   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
109     auto value = opaquePointerType(ctx);
110     return FunctionType::get(ctx, {value}, {});
111   }
112 
113   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
114     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
115   }
116 
117   static FunctionType executeFunctionType(MLIRContext *ctx) {
118     auto hdl = opaquePointerType(ctx);
119     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
120     return FunctionType::get(ctx, {hdl, resume}, {});
121   }
122 
123   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
124     auto i64 = IntegerType::get(ctx, 64);
125     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
126                              {i64});
127   }
128 
129   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
130     auto hdl = opaquePointerType(ctx);
131     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
132     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
133   }
134 
135   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
136     auto value = opaquePointerType(ctx);
137     auto hdl = opaquePointerType(ctx);
138     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
139     return FunctionType::get(ctx, {value, hdl, resume}, {});
140   }
141 
142   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
143     auto hdl = opaquePointerType(ctx);
144     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
145     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
146   }
147 
148   // Auxiliary coroutine resume intrinsic wrapper.
149   static Type resumeFunctionType(MLIRContext *ctx) {
150     auto voidTy = LLVM::LLVMVoidType::get(ctx);
151     auto i8Ptr = opaquePointerType(ctx);
152     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
153   }
154 };
155 } // namespace
156 
157 /// Adds Async Runtime C API declarations to the module.
158 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
159   auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
160                                                          module.getBody());
161 
162   auto addFuncDecl = [&](StringRef name, FunctionType type) {
163     if (module.lookupSymbol(name))
164       return;
165     builder.create<FuncOp>(name, type).setPrivate();
166   };
167 
168   MLIRContext *ctx = module.getContext();
169   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
170   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
171   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
172   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
173   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
174   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
175   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
176   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
177   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
178   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
179   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
180   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
181   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
182   addFuncDecl(kAwaitTokenAndExecute,
183               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
184   addFuncDecl(kAwaitValueAndExecute,
185               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
186   addFuncDecl(kAwaitAllAndExecute,
187               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // Add malloc/free declarations to the module.
192 //===----------------------------------------------------------------------===//
193 
194 static constexpr const char *kMalloc = "malloc";
195 static constexpr const char *kFree = "free";
196 
197 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
198                             StringRef name, Type ret, ArrayRef<Type> params) {
199   if (module.lookupSymbol(name))
200     return;
201   Type type = LLVM::LLVMFunctionType::get(ret, params);
202   builder.create<LLVM::LLVMFuncOp>(name, type);
203 }
204 
205 /// Adds malloc/free declarations to the module.
206 static void addCRuntimeDeclarations(ModuleOp module) {
207   using namespace mlir::LLVM;
208 
209   MLIRContext *ctx = module.getContext();
210   ImplicitLocOpBuilder builder(module.getLoc(),
211                                module.getBody()->getTerminator());
212 
213   auto voidTy = LLVMVoidType::get(ctx);
214   auto i64 = IntegerType::get(ctx, 64);
215   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
216 
217   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
218   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
219 }
220 
221 //===----------------------------------------------------------------------===//
222 // Coroutine resume function wrapper.
223 //===----------------------------------------------------------------------===//
224 
225 static constexpr const char *kResume = "__resume";
226 
227 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
228 /// intrinsics. We need this function to be able to pass it to the async
229 /// runtime execute API.
230 static void addResumeFunction(ModuleOp module) {
231   if (module.lookupSymbol(kResume))
232     return;
233 
234   MLIRContext *ctx = module.getContext();
235 
236   OpBuilder moduleBuilder(module.getBody()->getTerminator());
237   Location loc = module.getLoc();
238 
239   auto voidTy = LLVM::LLVMVoidType::get(ctx);
240   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
241 
242   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
243       loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
244   resumeOp.setPrivate();
245 
246   auto *block = resumeOp.addEntryBlock();
247   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
248 
249   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
250   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // Convert Async dialect types to LLVM types.
255 //===----------------------------------------------------------------------===//
256 
257 namespace {
258 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
259 /// their runtime type (opaque pointers) and does not convert any other types.
260 class AsyncRuntimeTypeConverter : public TypeConverter {
261 public:
262   AsyncRuntimeTypeConverter() {
263     addConversion([](Type type) { return type; });
264     addConversion(convertAsyncTypes);
265   }
266 
267   static Optional<Type> convertAsyncTypes(Type type) {
268     if (type.isa<TokenType, GroupType, ValueType>())
269       return AsyncAPI::opaquePointerType(type.getContext());
270 
271     if (type.isa<CoroIdType, CoroStateType>())
272       return AsyncAPI::tokenType(type.getContext());
273     if (type.isa<CoroHandleType>())
274       return AsyncAPI::opaquePointerType(type.getContext());
275 
276     return llvm::None;
277   }
278 };
279 } // namespace
280 
281 //===----------------------------------------------------------------------===//
282 // Convert async.coro.id to @llvm.coro.id intrinsic.
283 //===----------------------------------------------------------------------===//
284 
285 namespace {
286 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
287 public:
288   using OpConversionPattern::OpConversionPattern;
289 
290   LogicalResult
291   matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
292                   ConversionPatternRewriter &rewriter) const override {
293     auto token = AsyncAPI::tokenType(op->getContext());
294     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
295     auto loc = op->getLoc();
296 
297     // Constants for initializing coroutine frame.
298     auto constZero = rewriter.create<LLVM::ConstantOp>(
299         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
300     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
301 
302     // Get coroutine id: @llvm.coro.id.
303     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
304         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
305 
306     return success();
307   }
308 };
309 } // namespace
310 
311 //===----------------------------------------------------------------------===//
312 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
313 //===----------------------------------------------------------------------===//
314 
315 namespace {
316 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
317 public:
318   using OpConversionPattern::OpConversionPattern;
319 
320   LogicalResult
321   matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
322                   ConversionPatternRewriter &rewriter) const override {
323     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
324     auto loc = op->getLoc();
325 
326     // Get coroutine frame size: @llvm.coro.size.i64.
327     auto coroSize =
328         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
329 
330     // Allocate memory for the coroutine frame.
331     auto coroAlloc = rewriter.create<LLVM::CallOp>(
332         loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
333         ValueRange(coroSize.getResult()));
334 
335     // Begin a coroutine: @llvm.coro.begin.
336     auto coroId = CoroBeginOpAdaptor(operands).id();
337     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
338         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
339 
340     return success();
341   }
342 };
343 } // namespace
344 
345 //===----------------------------------------------------------------------===//
346 // Convert async.coro.free to @llvm.coro.free intrinsic.
347 //===----------------------------------------------------------------------===//
348 
349 namespace {
350 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
351 public:
352   using OpConversionPattern::OpConversionPattern;
353 
354   LogicalResult
355   matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
356                   ConversionPatternRewriter &rewriter) const override {
357     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
358     auto loc = op->getLoc();
359 
360     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
361     auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
362 
363     // Free the memory.
364     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
365                                               rewriter.getSymbolRefAttr(kFree),
366                                               ValueRange(coroMem.getResult()));
367 
368     return success();
369   }
370 };
371 } // namespace
372 
373 //===----------------------------------------------------------------------===//
374 // Convert async.coro.end to @llvm.coro.end intrinsic.
375 //===----------------------------------------------------------------------===//
376 
377 namespace {
378 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
379 public:
380   using OpConversionPattern::OpConversionPattern;
381 
382   LogicalResult
383   matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
384                   ConversionPatternRewriter &rewriter) const override {
385     // We are not in the block that is part of the unwind sequence.
386     auto constFalse = rewriter.create<LLVM::ConstantOp>(
387         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
388 
389     // Mark the end of a coroutine: @llvm.coro.end.
390     auto coroHdl = CoroEndOpAdaptor(operands).handle();
391     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
392                                      ValueRange({coroHdl, constFalse}));
393     rewriter.eraseOp(op);
394 
395     return success();
396   }
397 };
398 } // namespace
399 
400 //===----------------------------------------------------------------------===//
401 // Convert async.coro.save to @llvm.coro.save intrinsic.
402 //===----------------------------------------------------------------------===//
403 
404 namespace {
405 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
406 public:
407   using OpConversionPattern::OpConversionPattern;
408 
409   LogicalResult
410   matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
411                   ConversionPatternRewriter &rewriter) const override {
412     // Save the coroutine state: @llvm.coro.save
413     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
414         op, AsyncAPI::tokenType(op->getContext()), operands);
415 
416     return success();
417   }
418 };
419 } // namespace
420 
421 //===----------------------------------------------------------------------===//
422 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
423 //===----------------------------------------------------------------------===//
424 
425 namespace {
426 
427 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
428 /// branch to the appropriate block based on the return code.
429 ///
430 /// Before:
431 ///
432 ///   ^suspended:
433 ///     "opBefore"(...)
434 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
435 ///   ^resume:
436 ///     "op"(...)
437 ///   ^cleanup: ...
438 ///   ^suspend: ...
439 ///
440 /// After:
441 ///
442 ///   ^suspended:
443 ///     "opBefore"(...)
444 ///     %suspend = llmv.intr.coro.suspend ...
445 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
446 ///   ^resume:
447 ///     "op"(...)
448 ///   ^cleanup: ...
449 ///   ^suspend: ...
450 ///
451 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
452 public:
453   using OpConversionPattern::OpConversionPattern;
454 
455   LogicalResult
456   matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
457                   ConversionPatternRewriter &rewriter) const override {
458     auto i8 = rewriter.getIntegerType(8);
459     auto i32 = rewriter.getI32Type();
460     auto loc = op->getLoc();
461 
462     // This is not a final suspension point.
463     auto constFalse = rewriter.create<LLVM::ConstantOp>(
464         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
465 
466     // Suspend a coroutine: @llvm.coro.suspend
467     auto coroState = CoroSuspendOpAdaptor(operands).state();
468     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
469         loc, i8, ValueRange({coroState, constFalse}));
470 
471     // Cast return code to i32.
472 
473     // After a suspension point decide if we should branch into resume, cleanup
474     // or suspend block of the coroutine (see @llvm.coro.suspend return code
475     // documentation).
476     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
477     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
478                                               op.cleanupDest()};
479     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
480         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
481         /*defaultDestination=*/op.suspendDest(),
482         /*defaultOperands=*/ValueRange(),
483         /*caseValues=*/caseValues,
484         /*caseDestinations=*/caseDest,
485         /*caseOperands=*/ArrayRef<ValueRange>(),
486         /*branchWeights=*/ArrayRef<int32_t>());
487 
488     return success();
489   }
490 };
491 } // namespace
492 
493 //===----------------------------------------------------------------------===//
494 // Convert async.runtime.create to the corresponding runtime API call.
495 //
496 // To allocate storage for the async values we use getelementptr trick:
497 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
498 //===----------------------------------------------------------------------===//
499 
500 namespace {
501 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
502 public:
503   using OpConversionPattern::OpConversionPattern;
504 
505   LogicalResult
506   matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
507                   ConversionPatternRewriter &rewriter) const override {
508     TypeConverter *converter = getTypeConverter();
509     Type resultType = op->getResultTypes()[0];
510 
511     // Tokens and Groups lowered to function calls without arguments.
512     if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
513       rewriter.replaceOpWithNewOp<CallOp>(
514           op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
515           converter->convertType(resultType));
516       return success();
517     }
518 
519     // To create a value we need to compute the storage requirement.
520     if (auto value = resultType.dyn_cast<ValueType>()) {
521       // Returns the size requirements for the async value storage.
522       auto sizeOf = [&](ValueType valueType) -> Value {
523         auto loc = op->getLoc();
524         auto i32 = rewriter.getI32Type();
525 
526         auto storedType = converter->convertType(valueType.getValueType());
527         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
528 
529         // %Size = getelementptr %T* null, int 1
530         // %SizeI = ptrtoint %T* %Size to i32
531         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
532         auto one = rewriter.create<LLVM::ConstantOp>(
533             loc, i32, rewriter.getI32IntegerAttr(1));
534         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
535                                                 one.getResult());
536         return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep);
537       };
538 
539       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
540                                           sizeOf(value));
541 
542       return success();
543     }
544 
545     return rewriter.notifyMatchFailure(op, "unsupported async type");
546   }
547 };
548 } // namespace
549 
550 //===----------------------------------------------------------------------===//
551 // Convert async.runtime.set_available to the corresponding runtime API call.
552 //===----------------------------------------------------------------------===//
553 
554 namespace {
555 class RuntimeSetAvailableOpLowering
556     : public OpConversionPattern<RuntimeSetAvailableOp> {
557 public:
558   using OpConversionPattern::OpConversionPattern;
559 
560   LogicalResult
561   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
562                   ConversionPatternRewriter &rewriter) const override {
563     Type operandType = op.operand().getType();
564 
565     if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) {
566       rewriter.create<CallOp>(op->getLoc(),
567                               operandType.isa<TokenType>() ? kEmplaceToken
568                                                            : kEmplaceValue,
569                               TypeRange(), operands);
570       rewriter.eraseOp(op);
571       return success();
572     }
573 
574     return rewriter.notifyMatchFailure(op, "unsupported async type");
575   }
576 };
577 } // namespace
578 
579 //===----------------------------------------------------------------------===//
580 // Convert async.runtime.await to the corresponding runtime API call.
581 //===----------------------------------------------------------------------===//
582 
583 namespace {
584 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
585 public:
586   using OpConversionPattern::OpConversionPattern;
587 
588   LogicalResult
589   matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
590                   ConversionPatternRewriter &rewriter) const override {
591     Type operandType = op.operand().getType();
592 
593     StringRef apiFuncName;
594     if (operandType.isa<TokenType>())
595       apiFuncName = kAwaitToken;
596     else if (operandType.isa<ValueType>())
597       apiFuncName = kAwaitValue;
598     else if (operandType.isa<GroupType>())
599       apiFuncName = kAwaitGroup;
600     else
601       return rewriter.notifyMatchFailure(op, "unsupported async type");
602 
603     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
604     rewriter.eraseOp(op);
605 
606     return success();
607   }
608 };
609 } // namespace
610 
611 //===----------------------------------------------------------------------===//
612 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
613 //===----------------------------------------------------------------------===//
614 
615 namespace {
616 class RuntimeAwaitAndResumeOpLowering
617     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
618 public:
619   using OpConversionPattern::OpConversionPattern;
620 
621   LogicalResult
622   matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
623                   ConversionPatternRewriter &rewriter) const override {
624     Type operandType = op.operand().getType();
625 
626     StringRef apiFuncName;
627     if (operandType.isa<TokenType>())
628       apiFuncName = kAwaitTokenAndExecute;
629     else if (operandType.isa<ValueType>())
630       apiFuncName = kAwaitValueAndExecute;
631     else if (operandType.isa<GroupType>())
632       apiFuncName = kAwaitAllAndExecute;
633     else
634       return rewriter.notifyMatchFailure(op, "unsupported async type");
635 
636     Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
637     Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
638 
639     // A pointer to coroutine resume intrinsic wrapper.
640     addResumeFunction(op->getParentOfType<ModuleOp>());
641     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
642     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
643         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
644 
645     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
646                             ValueRange({operand, handle, resumePtr.res()}));
647     rewriter.eraseOp(op);
648 
649     return success();
650   }
651 };
652 } // namespace
653 
654 //===----------------------------------------------------------------------===//
655 // Convert async.runtime.resume to the corresponding runtime API call.
656 //===----------------------------------------------------------------------===//
657 
658 namespace {
659 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
660 public:
661   using OpConversionPattern::OpConversionPattern;
662 
663   LogicalResult
664   matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
665                   ConversionPatternRewriter &rewriter) const override {
666     // A pointer to coroutine resume intrinsic wrapper.
667     addResumeFunction(op->getParentOfType<ModuleOp>());
668     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
669     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
670         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
671 
672     // Call async runtime API to execute a coroutine in the managed thread.
673     auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
674     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
675                                         ValueRange({coroHdl, resumePtr.res()}));
676 
677     return success();
678   }
679 };
680 } // namespace
681 
682 //===----------------------------------------------------------------------===//
683 // Convert async.runtime.store to the corresponding runtime API call.
684 //===----------------------------------------------------------------------===//
685 
686 namespace {
687 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
688 public:
689   using OpConversionPattern::OpConversionPattern;
690 
691   LogicalResult
692   matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
693                   ConversionPatternRewriter &rewriter) const override {
694     Location loc = op->getLoc();
695 
696     // Get a pointer to the async value storage from the runtime.
697     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
698     auto storage = RuntimeStoreOpAdaptor(operands).storage();
699     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
700                                               TypeRange(i8Ptr), storage);
701 
702     // Cast from i8* to the LLVM pointer type.
703     auto valueType = op.value().getType();
704     auto llvmValueType = getTypeConverter()->convertType(valueType);
705     if (!llvmValueType)
706       return rewriter.notifyMatchFailure(
707           op, "failed to convert stored value type to LLVM type");
708 
709     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
710         loc, LLVM::LLVMPointerType::get(llvmValueType),
711         storagePtr.getResult(0));
712 
713     // Store the yielded value into the async value storage.
714     auto value = RuntimeStoreOpAdaptor(operands).value();
715     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
716 
717     // Erase the original runtime store operation.
718     rewriter.eraseOp(op);
719 
720     return success();
721   }
722 };
723 } // namespace
724 
725 //===----------------------------------------------------------------------===//
726 // Convert async.runtime.load to the corresponding runtime API call.
727 //===----------------------------------------------------------------------===//
728 
729 namespace {
730 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
731 public:
732   using OpConversionPattern::OpConversionPattern;
733 
734   LogicalResult
735   matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
736                   ConversionPatternRewriter &rewriter) const override {
737     Location loc = op->getLoc();
738 
739     // Get a pointer to the async value storage from the runtime.
740     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
741     auto storage = RuntimeLoadOpAdaptor(operands).storage();
742     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
743                                               TypeRange(i8Ptr), storage);
744 
745     // Cast from i8* to the LLVM pointer type.
746     auto valueType = op.result().getType();
747     auto llvmValueType = getTypeConverter()->convertType(valueType);
748     if (!llvmValueType)
749       return rewriter.notifyMatchFailure(
750           op, "failed to convert loaded value type to LLVM type");
751 
752     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
753         loc, LLVM::LLVMPointerType::get(llvmValueType),
754         storagePtr.getResult(0));
755 
756     // Load from the casted pointer.
757     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
758 
759     return success();
760   }
761 };
762 } // namespace
763 
764 //===----------------------------------------------------------------------===//
765 // Convert async.runtime.add_to_group to the corresponding runtime API call.
766 //===----------------------------------------------------------------------===//
767 
768 namespace {
769 class RuntimeAddToGroupOpLowering
770     : public OpConversionPattern<RuntimeAddToGroupOp> {
771 public:
772   using OpConversionPattern::OpConversionPattern;
773 
774   LogicalResult
775   matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
776                   ConversionPatternRewriter &rewriter) const override {
777     // Currently we can only add tokens to the group.
778     if (!op.operand().getType().isa<TokenType>())
779       return rewriter.notifyMatchFailure(op, "only token type is supported");
780 
781     // Replace with a runtime API function call.
782     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
783                                         rewriter.getI64Type(), operands);
784 
785     return success();
786   }
787 };
788 } // namespace
789 
790 //===----------------------------------------------------------------------===//
791 // Async reference counting ops lowering (`async.runtime.add_ref` and
792 // `async.runtime.drop_ref` to the corresponding API calls).
793 //===----------------------------------------------------------------------===//
794 
795 namespace {
796 template <typename RefCountingOp>
797 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
798 public:
799   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
800                                  StringRef apiFunctionName)
801       : OpConversionPattern<RefCountingOp>(converter, ctx),
802         apiFunctionName(apiFunctionName) {}
803 
804   LogicalResult
805   matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
806                   ConversionPatternRewriter &rewriter) const override {
807     auto count =
808         rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
809                                     rewriter.getI32IntegerAttr(op.count()));
810 
811     auto operand = typename RefCountingOp::Adaptor(operands).operand();
812     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
813                                         ValueRange({operand, count}));
814 
815     return success();
816   }
817 
818 private:
819   StringRef apiFunctionName;
820 };
821 
822 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
823 public:
824   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
825       : RefCountingOpLowering(converter, ctx, kAddRef) {}
826 };
827 
828 class RuntimeDropRefOpLowering
829     : public RefCountingOpLowering<RuntimeDropRefOp> {
830 public:
831   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
832       : RefCountingOpLowering(converter, ctx, kDropRef) {}
833 };
834 } // namespace
835 
836 //===----------------------------------------------------------------------===//
837 // Convert return operations that return async values from async regions.
838 //===----------------------------------------------------------------------===//
839 
840 namespace {
841 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
842 public:
843   using OpConversionPattern::OpConversionPattern;
844 
845   LogicalResult
846   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
847                   ConversionPatternRewriter &rewriter) const override {
848     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
849     return success();
850   }
851 };
852 } // namespace
853 
854 //===----------------------------------------------------------------------===//
855 
856 namespace {
857 struct ConvertAsyncToLLVMPass
858     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
859   void runOnOperation() override;
860 };
861 } // namespace
862 
863 void ConvertAsyncToLLVMPass::runOnOperation() {
864   ModuleOp module = getOperation();
865   MLIRContext *ctx = module->getContext();
866 
867   // Add declarations for most functions required by the coroutines lowering.
868   // We delay adding the resume function until it's needed because it currently
869   // fails to compile unless '-O0' is specified.
870   addAsyncRuntimeApiDeclarations(module);
871   addCRuntimeDeclarations(module);
872 
873   // Lower async.runtime and async.coro operations to Async Runtime API and
874   // LLVM coroutine intrinsics.
875 
876   // Convert async dialect types and operations to LLVM dialect.
877   AsyncRuntimeTypeConverter converter;
878   OwningRewritePatternList patterns;
879 
880   // We use conversion to LLVM type to lower async.runtime load and store
881   // operations.
882   LLVMTypeConverter llvmConverter(ctx);
883   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
884 
885   // Convert async types in function signatures and function calls.
886   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
887   populateCallOpTypeConversionPattern(patterns, ctx, converter);
888 
889   // Convert return operations inside async.execute regions.
890   patterns.insert<ReturnOpOpConversion>(converter, ctx);
891 
892   // Lower async.runtime operations to the async runtime API calls.
893   patterns.insert<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
894                   RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
895                   RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
896                   RuntimeDropRefOpLowering>(converter, ctx);
897 
898   // Lower async.runtime operations that rely on LLVM type converter to convert
899   // from async value payload type to the LLVM type.
900   patterns.insert<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
901                   RuntimeLoadOpLowering>(llvmConverter, ctx);
902 
903   // Lower async coroutine operations to LLVM coroutine intrinsics.
904   patterns.insert<CoroIdOpConversion, CoroBeginOpConversion,
905                   CoroFreeOpConversion, CoroEndOpConversion,
906                   CoroSaveOpConversion, CoroSuspendOpConversion>(converter,
907                                                                  ctx);
908 
909   ConversionTarget target(*ctx);
910   target.addLegalOp<ConstantOp>();
911   target.addLegalDialect<LLVM::LLVMDialect>();
912 
913   // All operations from Async dialect must be lowered to the runtime API and
914   // LLVM intrinsics calls.
915   target.addIllegalDialect<AsyncDialect>();
916 
917   // Add dynamic legality constraints to apply conversions defined above.
918   target.addDynamicallyLegalOp<FuncOp>(
919       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
920   target.addDynamicallyLegalOp<ReturnOp>(
921       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
922   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
923     return converter.isSignatureLegal(op.getCalleeType());
924   });
925 
926   if (failed(applyPartialConversion(module, target, std::move(patterns))))
927     signalPassFailure();
928 }
929 
930 //===----------------------------------------------------------------------===//
931 // Patterns for structural type conversions for the Async dialect operations.
932 //===----------------------------------------------------------------------===//
933 
934 namespace {
935 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
936 public:
937   using OpConversionPattern::OpConversionPattern;
938   LogicalResult
939   matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
940                   ConversionPatternRewriter &rewriter) const override {
941     ExecuteOp newOp =
942         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
943     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
944                                 newOp.getRegion().end());
945 
946     // Set operands and update block argument and result types.
947     newOp->setOperands(operands);
948     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
949       return failure();
950     for (auto result : newOp.getResults())
951       result.setType(typeConverter->convertType(result.getType()));
952 
953     rewriter.replaceOp(op, newOp.getResults());
954     return success();
955   }
956 };
957 
958 // Dummy pattern to trigger the appropriate type conversion / materialization.
959 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
960 public:
961   using OpConversionPattern::OpConversionPattern;
962   LogicalResult
963   matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
964                   ConversionPatternRewriter &rewriter) const override {
965     rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
966     return success();
967   }
968 };
969 
970 // Dummy pattern to trigger the appropriate type conversion / materialization.
971 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
972 public:
973   using OpConversionPattern::OpConversionPattern;
974   LogicalResult
975   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
976                   ConversionPatternRewriter &rewriter) const override {
977     rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
978     return success();
979   }
980 };
981 } // namespace
982 
983 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
984   return std::make_unique<ConvertAsyncToLLVMPass>();
985 }
986 
987 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
988     MLIRContext *context, TypeConverter &typeConverter,
989     OwningRewritePatternList &patterns, ConversionTarget &target) {
990   typeConverter.addConversion([&](TokenType type) { return type; });
991   typeConverter.addConversion([&](ValueType type) {
992     return ValueType::get(typeConverter.convertType(type.getValueType()));
993   });
994 
995   patterns
996       .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
997           typeConverter, context);
998 
999   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1000       [&](Operation *op) { return typeConverter.isLegal(op); });
1001 }
1002