1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define DEBUG_TYPE "convert-async-to-llvm"
24 
25 using namespace mlir;
26 using namespace mlir::async;
27 
28 //===----------------------------------------------------------------------===//
29 // Async Runtime C API declaration.
30 //===----------------------------------------------------------------------===//
31 
// Symbol names of the Async runtime C API functions. These names are declared
// in the module by addAsyncRuntimeApiDeclarations, and the lowering patterns
// in this file use them as call targets.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// The `...AndExecute` entry points additionally take a coroutine handle and a
// pointer to the resume function (see the matching AsyncAPI function types).
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
58 
59 namespace {
60 /// Async Runtime API function types.
61 ///
62 /// Because we can't create API function signature for type parametrized
63 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
64 /// lowering all async data types become opaque pointers at runtime.
65 struct AsyncAPI {
66   // All async types are lowered to opaque i8* LLVM pointers at runtime.
67   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
68     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
69   }
70 
71   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
72     return LLVM::LLVMTokenType::get(ctx);
73   }
74 
75   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
76     auto ref = opaquePointerType(ctx);
77     auto count = IntegerType::get(ctx, 32);
78     return FunctionType::get(ctx, {ref, count}, {});
79   }
80 
81   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
82     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
83   }
84 
85   static FunctionType createValueFunctionType(MLIRContext *ctx) {
86     auto i32 = IntegerType::get(ctx, 32);
87     auto value = opaquePointerType(ctx);
88     return FunctionType::get(ctx, {i32}, {value});
89   }
90 
91   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
92     auto i64 = IntegerType::get(ctx, 64);
93     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
94   }
95 
96   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
97     auto value = opaquePointerType(ctx);
98     auto storage = opaquePointerType(ctx);
99     return FunctionType::get(ctx, {value}, {storage});
100   }
101 
102   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
103     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
104   }
105 
106   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
107     auto value = opaquePointerType(ctx);
108     return FunctionType::get(ctx, {value}, {});
109   }
110 
111   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
112     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
113   }
114 
115   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
116     auto value = opaquePointerType(ctx);
117     return FunctionType::get(ctx, {value}, {});
118   }
119 
120   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
121     auto i1 = IntegerType::get(ctx, 1);
122     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
123   }
124 
125   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
126     auto value = opaquePointerType(ctx);
127     auto i1 = IntegerType::get(ctx, 1);
128     return FunctionType::get(ctx, {value}, {i1});
129   }
130 
131   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
132     auto i1 = IntegerType::get(ctx, 1);
133     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
134   }
135 
136   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
137     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
138   }
139 
140   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
141     auto value = opaquePointerType(ctx);
142     return FunctionType::get(ctx, {value}, {});
143   }
144 
145   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
146     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
147   }
148 
149   static FunctionType executeFunctionType(MLIRContext *ctx) {
150     auto hdl = opaquePointerType(ctx);
151     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
152     return FunctionType::get(ctx, {hdl, resume}, {});
153   }
154 
155   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
156     auto i64 = IntegerType::get(ctx, 64);
157     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
158                              {i64});
159   }
160 
161   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
162     auto hdl = opaquePointerType(ctx);
163     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
164     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
165   }
166 
167   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
168     auto value = opaquePointerType(ctx);
169     auto hdl = opaquePointerType(ctx);
170     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
171     return FunctionType::get(ctx, {value, hdl, resume}, {});
172   }
173 
174   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
175     auto hdl = opaquePointerType(ctx);
176     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
177     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
178   }
179 
180   // Auxiliary coroutine resume intrinsic wrapper.
181   static Type resumeFunctionType(MLIRContext *ctx) {
182     auto voidTy = LLVM::LLVMVoidType::get(ctx);
183     auto i8Ptr = opaquePointerType(ctx);
184     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
185   }
186 };
187 } // namespace
188 
189 /// Adds Async Runtime C API declarations to the module.
190 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
191   auto builder =
192       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
193 
194   auto addFuncDecl = [&](StringRef name, FunctionType type) {
195     if (module.lookupSymbol(name))
196       return;
197     builder.create<FuncOp>(name, type).setPrivate();
198   };
199 
200   MLIRContext *ctx = module.getContext();
201   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
202   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
203   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
204   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
205   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
206   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
207   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
208   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
209   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
210   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
211   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
212   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
213   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
214   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
215   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
216   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
217   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
218   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
219   addFuncDecl(kAwaitTokenAndExecute,
220               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
221   addFuncDecl(kAwaitValueAndExecute,
222               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
223   addFuncDecl(kAwaitAllAndExecute,
224               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // Add malloc/free declarations to the module.
229 //===----------------------------------------------------------------------===//
230 
231 static constexpr const char *kMalloc = "malloc";
232 static constexpr const char *kFree = "free";
233 
234 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
235                             StringRef name, Type ret, ArrayRef<Type> params) {
236   if (module.lookupSymbol(name))
237     return;
238   Type type = LLVM::LLVMFunctionType::get(ret, params);
239   builder.create<LLVM::LLVMFuncOp>(name, type);
240 }
241 
242 /// Adds malloc/free declarations to the module.
243 static void addCRuntimeDeclarations(ModuleOp module) {
244   using namespace mlir::LLVM;
245 
246   MLIRContext *ctx = module.getContext();
247   auto builder =
248       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
249 
250   auto voidTy = LLVMVoidType::get(ctx);
251   auto i64 = IntegerType::get(ctx, 64);
252   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
253 
254   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
255   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // Coroutine resume function wrapper.
260 //===----------------------------------------------------------------------===//
261 
262 static constexpr const char *kResume = "__resume";
263 
264 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
265 /// intrinsics. We need this function to be able to pass it to the async
266 /// runtime execute API.
267 static void addResumeFunction(ModuleOp module) {
268   if (module.lookupSymbol(kResume))
269     return;
270 
271   MLIRContext *ctx = module.getContext();
272   auto loc = module.getLoc();
273   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
274 
275   auto voidTy = LLVM::LLVMVoidType::get(ctx);
276   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
277 
278   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
279       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
280   resumeOp.setPrivate();
281 
282   auto *block = resumeOp.addEntryBlock();
283   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
284 
285   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
286   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // Convert Async dialect types to LLVM types.
291 //===----------------------------------------------------------------------===//
292 
293 namespace {
294 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
295 /// their runtime type (opaque pointers) and does not convert any other types.
296 class AsyncRuntimeTypeConverter : public TypeConverter {
297 public:
298   AsyncRuntimeTypeConverter() {
299     addConversion([](Type type) { return type; });
300     addConversion(convertAsyncTypes);
301   }
302 
303   static Optional<Type> convertAsyncTypes(Type type) {
304     if (type.isa<TokenType, GroupType, ValueType>())
305       return AsyncAPI::opaquePointerType(type.getContext());
306 
307     if (type.isa<CoroIdType, CoroStateType>())
308       return AsyncAPI::tokenType(type.getContext());
309     if (type.isa<CoroHandleType>())
310       return AsyncAPI::opaquePointerType(type.getContext());
311 
312     return llvm::None;
313   }
314 };
315 } // namespace
316 
317 //===----------------------------------------------------------------------===//
318 // Convert async.coro.id to @llvm.coro.id intrinsic.
319 //===----------------------------------------------------------------------===//
320 
321 namespace {
322 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
323 public:
324   using OpConversionPattern::OpConversionPattern;
325 
326   LogicalResult
327   matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
328                   ConversionPatternRewriter &rewriter) const override {
329     auto token = AsyncAPI::tokenType(op->getContext());
330     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
331     auto loc = op->getLoc();
332 
333     // Constants for initializing coroutine frame.
334     auto constZero = rewriter.create<LLVM::ConstantOp>(
335         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
336     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
337 
338     // Get coroutine id: @llvm.coro.id.
339     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
340         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
341 
342     return success();
343   }
344 };
345 } // namespace
346 
347 //===----------------------------------------------------------------------===//
348 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
349 //===----------------------------------------------------------------------===//
350 
351 namespace {
352 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
353 public:
354   using OpConversionPattern::OpConversionPattern;
355 
356   LogicalResult
357   matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
358                   ConversionPatternRewriter &rewriter) const override {
359     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
360     auto loc = op->getLoc();
361 
362     // Get coroutine frame size: @llvm.coro.size.i64.
363     auto coroSize =
364         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
365 
366     // Allocate memory for the coroutine frame.
367     auto coroAlloc = rewriter.create<LLVM::CallOp>(
368         loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
369         ValueRange(coroSize.getResult()));
370 
371     // Begin a coroutine: @llvm.coro.begin.
372     auto coroId = CoroBeginOpAdaptor(operands).id();
373     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
374         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
375 
376     return success();
377   }
378 };
379 } // namespace
380 
381 //===----------------------------------------------------------------------===//
382 // Convert async.coro.free to @llvm.coro.free intrinsic.
383 //===----------------------------------------------------------------------===//
384 
385 namespace {
386 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
387 public:
388   using OpConversionPattern::OpConversionPattern;
389 
390   LogicalResult
391   matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
392                   ConversionPatternRewriter &rewriter) const override {
393     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
394     auto loc = op->getLoc();
395 
396     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
397     auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
398 
399     // Free the memory.
400     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
401                                               rewriter.getSymbolRefAttr(kFree),
402                                               ValueRange(coroMem.getResult()));
403 
404     return success();
405   }
406 };
407 } // namespace
408 
409 //===----------------------------------------------------------------------===//
410 // Convert async.coro.end to @llvm.coro.end intrinsic.
411 //===----------------------------------------------------------------------===//
412 
413 namespace {
414 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
415 public:
416   using OpConversionPattern::OpConversionPattern;
417 
418   LogicalResult
419   matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
420                   ConversionPatternRewriter &rewriter) const override {
421     // We are not in the block that is part of the unwind sequence.
422     auto constFalse = rewriter.create<LLVM::ConstantOp>(
423         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
424 
425     // Mark the end of a coroutine: @llvm.coro.end.
426     auto coroHdl = CoroEndOpAdaptor(operands).handle();
427     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
428                                      ValueRange({coroHdl, constFalse}));
429     rewriter.eraseOp(op);
430 
431     return success();
432   }
433 };
434 } // namespace
435 
436 //===----------------------------------------------------------------------===//
437 // Convert async.coro.save to @llvm.coro.save intrinsic.
438 //===----------------------------------------------------------------------===//
439 
440 namespace {
441 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
442 public:
443   using OpConversionPattern::OpConversionPattern;
444 
445   LogicalResult
446   matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
447                   ConversionPatternRewriter &rewriter) const override {
448     // Save the coroutine state: @llvm.coro.save
449     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
450         op, AsyncAPI::tokenType(op->getContext()), operands);
451 
452     return success();
453   }
454 };
455 } // namespace
456 
457 //===----------------------------------------------------------------------===//
458 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
459 //===----------------------------------------------------------------------===//
460 
461 namespace {
462 
463 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
464 /// branch to the appropriate block based on the return code.
465 ///
466 /// Before:
467 ///
468 ///   ^suspended:
469 ///     "opBefore"(...)
470 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
471 ///   ^resume:
472 ///     "op"(...)
473 ///   ^cleanup: ...
474 ///   ^suspend: ...
475 ///
476 /// After:
477 ///
478 ///   ^suspended:
479 ///     "opBefore"(...)
480 ///     %suspend = llmv.intr.coro.suspend ...
481 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
482 ///   ^resume:
483 ///     "op"(...)
484 ///   ^cleanup: ...
485 ///   ^suspend: ...
486 ///
487 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
488 public:
489   using OpConversionPattern::OpConversionPattern;
490 
491   LogicalResult
492   matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
493                   ConversionPatternRewriter &rewriter) const override {
494     auto i8 = rewriter.getIntegerType(8);
495     auto i32 = rewriter.getI32Type();
496     auto loc = op->getLoc();
497 
498     // This is not a final suspension point.
499     auto constFalse = rewriter.create<LLVM::ConstantOp>(
500         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
501 
502     // Suspend a coroutine: @llvm.coro.suspend
503     auto coroState = CoroSuspendOpAdaptor(operands).state();
504     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
505         loc, i8, ValueRange({coroState, constFalse}));
506 
507     // Cast return code to i32.
508 
509     // After a suspension point decide if we should branch into resume, cleanup
510     // or suspend block of the coroutine (see @llvm.coro.suspend return code
511     // documentation).
512     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
513     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
514                                               op.cleanupDest()};
515     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
516         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
517         /*defaultDestination=*/op.suspendDest(),
518         /*defaultOperands=*/ValueRange(),
519         /*caseValues=*/caseValues,
520         /*caseDestinations=*/caseDest,
521         /*caseOperands=*/ArrayRef<ValueRange>(),
522         /*branchWeights=*/ArrayRef<int32_t>());
523 
524     return success();
525   }
526 };
527 } // namespace
528 
529 //===----------------------------------------------------------------------===//
530 // Convert async.runtime.create to the corresponding runtime API call.
531 //
532 // To allocate storage for the async values we use getelementptr trick:
533 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
534 //===----------------------------------------------------------------------===//
535 
536 namespace {
537 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
538 public:
539   using OpConversionPattern::OpConversionPattern;
540 
541   LogicalResult
542   matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
543                   ConversionPatternRewriter &rewriter) const override {
544     TypeConverter *converter = getTypeConverter();
545     Type resultType = op->getResultTypes()[0];
546 
547     // Tokens creation maps to a simple function call.
548     if (resultType.isa<TokenType>()) {
549       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
550                                           converter->convertType(resultType));
551       return success();
552     }
553 
554     // To create a value we need to compute the storage requirement.
555     if (auto value = resultType.dyn_cast<ValueType>()) {
556       // Returns the size requirements for the async value storage.
557       auto sizeOf = [&](ValueType valueType) -> Value {
558         auto loc = op->getLoc();
559         auto i32 = rewriter.getI32Type();
560 
561         auto storedType = converter->convertType(valueType.getValueType());
562         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
563 
564         // %Size = getelementptr %T* null, int 1
565         // %SizeI = ptrtoint %T* %Size to i32
566         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
567         auto one = rewriter.create<LLVM::ConstantOp>(
568             loc, i32, rewriter.getI32IntegerAttr(1));
569         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
570                                                 one.getResult());
571         return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep);
572       };
573 
574       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
575                                           sizeOf(value));
576 
577       return success();
578     }
579 
580     return rewriter.notifyMatchFailure(op, "unsupported async type");
581   }
582 };
583 } // namespace
584 
585 //===----------------------------------------------------------------------===//
586 // Convert async.runtime.create_group to the corresponding runtime API call.
587 //===----------------------------------------------------------------------===//
588 
589 namespace {
590 class RuntimeCreateGroupOpLowering
591     : public OpConversionPattern<RuntimeCreateGroupOp> {
592 public:
593   using OpConversionPattern::OpConversionPattern;
594 
595   LogicalResult
596   matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
597                   ConversionPatternRewriter &rewriter) const override {
598     TypeConverter *converter = getTypeConverter();
599     Type resultType = op.getResult().getType();
600 
601     rewriter.replaceOpWithNewOp<CallOp>(
602         op, kCreateGroup, converter->convertType(resultType), operands);
603     return success();
604   }
605 };
606 } // namespace
607 
608 //===----------------------------------------------------------------------===//
609 // Convert async.runtime.set_available to the corresponding runtime API call.
610 //===----------------------------------------------------------------------===//
611 
612 namespace {
613 class RuntimeSetAvailableOpLowering
614     : public OpConversionPattern<RuntimeSetAvailableOp> {
615 public:
616   using OpConversionPattern::OpConversionPattern;
617 
618   LogicalResult
619   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
620                   ConversionPatternRewriter &rewriter) const override {
621     StringRef apiFuncName =
622         TypeSwitch<Type, StringRef>(op.operand().getType())
623             .Case<TokenType>([](Type) { return kEmplaceToken; })
624             .Case<ValueType>([](Type) { return kEmplaceValue; });
625 
626     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
627 
628     return success();
629   }
630 };
631 } // namespace
632 
633 //===----------------------------------------------------------------------===//
634 // Convert async.runtime.set_error to the corresponding runtime API call.
635 //===----------------------------------------------------------------------===//
636 
637 namespace {
638 class RuntimeSetErrorOpLowering
639     : public OpConversionPattern<RuntimeSetErrorOp> {
640 public:
641   using OpConversionPattern::OpConversionPattern;
642 
643   LogicalResult
644   matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
645                   ConversionPatternRewriter &rewriter) const override {
646     StringRef apiFuncName =
647         TypeSwitch<Type, StringRef>(op.operand().getType())
648             .Case<TokenType>([](Type) { return kSetTokenError; })
649             .Case<ValueType>([](Type) { return kSetValueError; });
650 
651     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
652 
653     return success();
654   }
655 };
656 } // namespace
657 
658 //===----------------------------------------------------------------------===//
659 // Convert async.runtime.is_error to the corresponding runtime API call.
660 //===----------------------------------------------------------------------===//
661 
662 namespace {
663 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
664 public:
665   using OpConversionPattern::OpConversionPattern;
666 
667   LogicalResult
668   matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
669                   ConversionPatternRewriter &rewriter) const override {
670     StringRef apiFuncName =
671         TypeSwitch<Type, StringRef>(op.operand().getType())
672             .Case<TokenType>([](Type) { return kIsTokenError; })
673             .Case<GroupType>([](Type) { return kIsGroupError; })
674             .Case<ValueType>([](Type) { return kIsValueError; });
675 
676     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
677                                         operands);
678     return success();
679   }
680 };
681 } // namespace
682 
683 //===----------------------------------------------------------------------===//
684 // Convert async.runtime.await to the corresponding runtime API call.
685 //===----------------------------------------------------------------------===//
686 
687 namespace {
688 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
689 public:
690   using OpConversionPattern::OpConversionPattern;
691 
692   LogicalResult
693   matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
694                   ConversionPatternRewriter &rewriter) const override {
695     StringRef apiFuncName =
696         TypeSwitch<Type, StringRef>(op.operand().getType())
697             .Case<TokenType>([](Type) { return kAwaitToken; })
698             .Case<ValueType>([](Type) { return kAwaitValue; })
699             .Case<GroupType>([](Type) { return kAwaitGroup; });
700 
701     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
702     rewriter.eraseOp(op);
703 
704     return success();
705   }
706 };
707 } // namespace
708 
709 //===----------------------------------------------------------------------===//
710 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
711 //===----------------------------------------------------------------------===//
712 
713 namespace {
714 class RuntimeAwaitAndResumeOpLowering
715     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
716 public:
717   using OpConversionPattern::OpConversionPattern;
718 
719   LogicalResult
720   matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
721                   ConversionPatternRewriter &rewriter) const override {
722     StringRef apiFuncName =
723         TypeSwitch<Type, StringRef>(op.operand().getType())
724             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
725             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
726             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
727 
728     Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
729     Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
730 
731     // A pointer to coroutine resume intrinsic wrapper.
732     addResumeFunction(op->getParentOfType<ModuleOp>());
733     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
734     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
735         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
736 
737     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
738                             ValueRange({operand, handle, resumePtr.res()}));
739     rewriter.eraseOp(op);
740 
741     return success();
742   }
743 };
744 } // namespace
745 
746 //===----------------------------------------------------------------------===//
747 // Convert async.runtime.resume to the corresponding runtime API call.
748 //===----------------------------------------------------------------------===//
749 
750 namespace {
751 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
752 public:
753   using OpConversionPattern::OpConversionPattern;
754 
755   LogicalResult
756   matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
757                   ConversionPatternRewriter &rewriter) const override {
758     // A pointer to coroutine resume intrinsic wrapper.
759     addResumeFunction(op->getParentOfType<ModuleOp>());
760     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
761     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
762         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
763 
764     // Call async runtime API to execute a coroutine in the managed thread.
765     auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
766     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
767                                         ValueRange({coroHdl, resumePtr.res()}));
768 
769     return success();
770   }
771 };
772 } // namespace
773 
774 //===----------------------------------------------------------------------===//
775 // Convert async.runtime.store to the corresponding runtime API call.
776 //===----------------------------------------------------------------------===//
777 
778 namespace {
779 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
780 public:
781   using OpConversionPattern::OpConversionPattern;
782 
783   LogicalResult
784   matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
785                   ConversionPatternRewriter &rewriter) const override {
786     Location loc = op->getLoc();
787 
788     // Get a pointer to the async value storage from the runtime.
789     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
790     auto storage = RuntimeStoreOpAdaptor(operands).storage();
791     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
792                                               TypeRange(i8Ptr), storage);
793 
794     // Cast from i8* to the LLVM pointer type.
795     auto valueType = op.value().getType();
796     auto llvmValueType = getTypeConverter()->convertType(valueType);
797     if (!llvmValueType)
798       return rewriter.notifyMatchFailure(
799           op, "failed to convert stored value type to LLVM type");
800 
801     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
802         loc, LLVM::LLVMPointerType::get(llvmValueType),
803         storagePtr.getResult(0));
804 
805     // Store the yielded value into the async value storage.
806     auto value = RuntimeStoreOpAdaptor(operands).value();
807     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
808 
809     // Erase the original runtime store operation.
810     rewriter.eraseOp(op);
811 
812     return success();
813   }
814 };
815 } // namespace
816 
817 //===----------------------------------------------------------------------===//
818 // Convert async.runtime.load to the corresponding runtime API call.
819 //===----------------------------------------------------------------------===//
820 
821 namespace {
822 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
823 public:
824   using OpConversionPattern::OpConversionPattern;
825 
826   LogicalResult
827   matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
828                   ConversionPatternRewriter &rewriter) const override {
829     Location loc = op->getLoc();
830 
831     // Get a pointer to the async value storage from the runtime.
832     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
833     auto storage = RuntimeLoadOpAdaptor(operands).storage();
834     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
835                                               TypeRange(i8Ptr), storage);
836 
837     // Cast from i8* to the LLVM pointer type.
838     auto valueType = op.result().getType();
839     auto llvmValueType = getTypeConverter()->convertType(valueType);
840     if (!llvmValueType)
841       return rewriter.notifyMatchFailure(
842           op, "failed to convert loaded value type to LLVM type");
843 
844     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
845         loc, LLVM::LLVMPointerType::get(llvmValueType),
846         storagePtr.getResult(0));
847 
848     // Load from the casted pointer.
849     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
850 
851     return success();
852   }
853 };
854 } // namespace
855 
856 //===----------------------------------------------------------------------===//
857 // Convert async.runtime.add_to_group to the corresponding runtime API call.
858 //===----------------------------------------------------------------------===//
859 
860 namespace {
861 class RuntimeAddToGroupOpLowering
862     : public OpConversionPattern<RuntimeAddToGroupOp> {
863 public:
864   using OpConversionPattern::OpConversionPattern;
865 
866   LogicalResult
867   matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
868                   ConversionPatternRewriter &rewriter) const override {
869     // Currently we can only add tokens to the group.
870     if (!op.operand().getType().isa<TokenType>())
871       return rewriter.notifyMatchFailure(op, "only token type is supported");
872 
873     // Replace with a runtime API function call.
874     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
875                                         rewriter.getI64Type(), operands);
876 
877     return success();
878   }
879 };
880 } // namespace
881 
882 //===----------------------------------------------------------------------===//
883 // Async reference counting ops lowering (`async.runtime.add_ref` and
884 // `async.runtime.drop_ref` to the corresponding API calls).
885 //===----------------------------------------------------------------------===//
886 
887 namespace {
888 template <typename RefCountingOp>
889 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
890 public:
891   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
892                                  StringRef apiFunctionName)
893       : OpConversionPattern<RefCountingOp>(converter, ctx),
894         apiFunctionName(apiFunctionName) {}
895 
896   LogicalResult
897   matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
898                   ConversionPatternRewriter &rewriter) const override {
899     auto count =
900         rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
901                                     rewriter.getI32IntegerAttr(op.count()));
902 
903     auto operand = typename RefCountingOp::Adaptor(operands).operand();
904     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
905                                         ValueRange({operand, count}));
906 
907     return success();
908   }
909 
910 private:
911   StringRef apiFunctionName;
912 };
913 
914 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
915 public:
916   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
917       : RefCountingOpLowering(converter, ctx, kAddRef) {}
918 };
919 
920 class RuntimeDropRefOpLowering
921     : public RefCountingOpLowering<RuntimeDropRefOp> {
922 public:
923   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
924       : RefCountingOpLowering(converter, ctx, kDropRef) {}
925 };
926 } // namespace
927 
928 //===----------------------------------------------------------------------===//
929 // Convert return operations that return async values from async regions.
930 //===----------------------------------------------------------------------===//
931 
932 namespace {
933 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
934 public:
935   using OpConversionPattern::OpConversionPattern;
936 
937   LogicalResult
938   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
939                   ConversionPatternRewriter &rewriter) const override {
940     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
941     return success();
942   }
943 };
944 } // namespace
945 
946 //===----------------------------------------------------------------------===//
947 
948 namespace {
/// Pass that lowers Async dialect runtime and coroutine operations to async
/// runtime API calls and the LLVM dialect. The lowering itself is implemented
/// in `runOnOperation`.
struct ConvertAsyncToLLVMPass
    : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
  void runOnOperation() override;
};
953 } // namespace
954 
/// Runs the Async to LLVM lowering on the module: declares the required
/// runtime functions, registers all conversion patterns, configures the
/// conversion target, and applies a partial conversion. Signals pass failure
/// if the conversion fails.
void ConvertAsyncToLLVMPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext *ctx = module->getContext();

  // Add declarations for most functions required by the coroutines lowering.
  // We delay adding the resume function until it's needed because it currently
  // fails to compile unless '-O0' is specified.
  addAsyncRuntimeApiDeclarations(module);
  addCRuntimeDeclarations(module);

  // Lower async.runtime and async.coro operations to the Async Runtime API
  // and LLVM coroutine intrinsics.

  // `converter` handles async dialect types; it is used by the patterns that
  // do not need to convert an async value payload type.
  AsyncRuntimeTypeConverter converter;
  RewritePatternSet patterns(ctx);

  // We use conversion to LLVM type to lower async.runtime load and store
  // operations. The async type conversion is registered on top of the LLVM
  // type converter so that both kinds of types are handled.
  LLVMTypeConverter llvmConverter(ctx);
  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);

  // Convert async types in function signatures and function calls.
  populateFuncOpTypeConversionPattern(patterns, converter);
  populateCallOpTypeConversionPattern(patterns, converter);

  // Convert return operations inside async.execute regions.
  patterns.add<ReturnOpOpConversion>(converter, ctx);

  // Lower async.runtime operations to the async runtime API calls.
  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
               RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
               RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
               RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
               RuntimeDropRefOpLowering>(converter, ctx);

  // Lower async.runtime operations that rely on LLVM type converter to convert
  // from async value payload type to the LLVM type.
  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
               RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
                                                              ctx);

  // Lower async coroutine operations to LLVM coroutine intrinsics.
  patterns
      .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
           CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
          converter, ctx);

  // Constants and everything in the LLVM dialect are legal as-is.
  ConversionTarget target(*ctx);
  target.addLegalOp<ConstantOp>();
  target.addLegalDialect<LLVM::LLVMDialect>();

  // All operations from Async dialect must be lowered to the runtime API and
  // LLVM intrinsics calls.
  target.addIllegalDialect<AsyncDialect>();

  // Add dynamic legality constraints to apply conversions defined above:
  // functions, returns and calls are legal only once their types are legal
  // according to `converter`.
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addDynamicallyLegalOp<ReturnOp>(
      [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
    return converter.isSignatureLegal(op.getCalleeType());
  });

  // Partial conversion: operations not mentioned by the target are left
  // untouched.
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
1023 
1024 //===----------------------------------------------------------------------===//
1025 // Patterns for structural type conversions for the Async dialect operations.
1026 //===----------------------------------------------------------------------===//
1027 
1028 namespace {
1029 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1030 public:
1031   using OpConversionPattern::OpConversionPattern;
1032   LogicalResult
1033   matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
1034                   ConversionPatternRewriter &rewriter) const override {
1035     ExecuteOp newOp =
1036         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1037     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1038                                 newOp.getRegion().end());
1039 
1040     // Set operands and update block argument and result types.
1041     newOp->setOperands(operands);
1042     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1043       return failure();
1044     for (auto result : newOp.getResults())
1045       result.setType(typeConverter->convertType(result.getType()));
1046 
1047     rewriter.replaceOp(op, newOp.getResults());
1048     return success();
1049   }
1050 };
1051 
1052 // Dummy pattern to trigger the appropriate type conversion / materialization.
1053 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1054 public:
1055   using OpConversionPattern::OpConversionPattern;
1056   LogicalResult
1057   matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
1058                   ConversionPatternRewriter &rewriter) const override {
1059     rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
1060     return success();
1061   }
1062 };
1063 
1064 // Dummy pattern to trigger the appropriate type conversion / materialization.
1065 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1066 public:
1067   using OpConversionPattern::OpConversionPattern;
1068   LogicalResult
1069   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
1070                   ConversionPatternRewriter &rewriter) const override {
1071     rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
1072     return success();
1073   }
1074 };
1075 } // namespace
1076 
1077 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1078   return std::make_unique<ConvertAsyncToLLVMPass>();
1079 }
1080 
1081 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1082     TypeConverter &typeConverter, RewritePatternSet &patterns,
1083     ConversionTarget &target) {
1084   typeConverter.addConversion([&](TokenType type) { return type; });
1085   typeConverter.addConversion([&](ValueType type) {
1086     Type converted = typeConverter.convertType(type.getValueType());
1087     return converted ? ValueType::get(converted) : converted;
1088   });
1089 
1090   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1091       typeConverter, patterns.getContext());
1092 
1093   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1094       [&](Operation *op) { return typeConverter.isLegal(op); });
1095 }
1096