1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Dialect/Async/IR/Async.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 #define DEBUG_TYPE "convert-async-to-llvm"
26 
27 using namespace mlir;
28 using namespace mlir::async;
29 
30 //===----------------------------------------------------------------------===//
31 // Async Runtime C API declaration.
32 //===----------------------------------------------------------------------===//
33 
// Symbol names of the Async Runtime C API. addAsyncRuntimeApiDeclarations()
// declares each of these as a body-less private function in the module and
// the lowering patterns emit calls to them by name, so every string must
// match the symbol that is available when the lowered module is linked/run.
// NOTE(review): the definitions presumably live in the MLIR async runtime
// library -- keep these names in sync with it when renaming.

// Reference counting of async objects.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Creation of async tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Switching tokens/values to the available state.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
// Switching tokens/values to the error state.
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
// Querying the error state (each returns an i1).
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
// Awaiting tokens/values/groups from the calling code.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Handing a coroutine handle + resume function pointer to the runtime.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Access to the storage backing an async value.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
// Adding a token to an async group (returns an i64).
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// Await variants that take a coroutine handle and resume function pointer.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
60 
61 namespace {
62 /// Async Runtime API function types.
63 ///
64 /// Because we can't create API function signature for type parametrized
65 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
66 /// lowering all async data types become opaque pointers at runtime.
67 struct AsyncAPI {
68   // All async types are lowered to opaque i8* LLVM pointers at runtime.
69   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
70     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
71   }
72 
73   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
74     return LLVM::LLVMTokenType::get(ctx);
75   }
76 
77   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
78     auto ref = opaquePointerType(ctx);
79     auto count = IntegerType::get(ctx, 32);
80     return FunctionType::get(ctx, {ref, count}, {});
81   }
82 
83   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
84     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
85   }
86 
87   static FunctionType createValueFunctionType(MLIRContext *ctx) {
88     auto i32 = IntegerType::get(ctx, 32);
89     auto value = opaquePointerType(ctx);
90     return FunctionType::get(ctx, {i32}, {value});
91   }
92 
93   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
94     auto i64 = IntegerType::get(ctx, 64);
95     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
96   }
97 
98   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
99     auto value = opaquePointerType(ctx);
100     auto storage = opaquePointerType(ctx);
101     return FunctionType::get(ctx, {value}, {storage});
102   }
103 
104   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
105     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
106   }
107 
108   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
109     auto value = opaquePointerType(ctx);
110     return FunctionType::get(ctx, {value}, {});
111   }
112 
113   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
114     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
115   }
116 
117   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
118     auto value = opaquePointerType(ctx);
119     return FunctionType::get(ctx, {value}, {});
120   }
121 
122   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
123     auto i1 = IntegerType::get(ctx, 1);
124     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
125   }
126 
127   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
128     auto value = opaquePointerType(ctx);
129     auto i1 = IntegerType::get(ctx, 1);
130     return FunctionType::get(ctx, {value}, {i1});
131   }
132 
133   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
134     auto i1 = IntegerType::get(ctx, 1);
135     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
136   }
137 
138   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
139     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
140   }
141 
142   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
143     auto value = opaquePointerType(ctx);
144     return FunctionType::get(ctx, {value}, {});
145   }
146 
147   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
148     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
149   }
150 
151   static FunctionType executeFunctionType(MLIRContext *ctx) {
152     auto hdl = opaquePointerType(ctx);
153     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
154     return FunctionType::get(ctx, {hdl, resume}, {});
155   }
156 
157   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
158     auto i64 = IntegerType::get(ctx, 64);
159     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
160                              {i64});
161   }
162 
163   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
164     auto hdl = opaquePointerType(ctx);
165     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
166     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
167   }
168 
169   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
170     auto value = opaquePointerType(ctx);
171     auto hdl = opaquePointerType(ctx);
172     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
173     return FunctionType::get(ctx, {value, hdl, resume}, {});
174   }
175 
176   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
177     auto hdl = opaquePointerType(ctx);
178     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
179     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
180   }
181 
182   // Auxiliary coroutine resume intrinsic wrapper.
183   static Type resumeFunctionType(MLIRContext *ctx) {
184     auto voidTy = LLVM::LLVMVoidType::get(ctx);
185     auto i8Ptr = opaquePointerType(ctx);
186     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
187   }
188 };
189 } // namespace
190 
191 /// Adds Async Runtime C API declarations to the module.
192 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
193   auto builder =
194       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
195 
196   auto addFuncDecl = [&](StringRef name, FunctionType type) {
197     if (module.lookupSymbol(name))
198       return;
199     builder.create<FuncOp>(name, type).setPrivate();
200   };
201 
202   MLIRContext *ctx = module.getContext();
203   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
204   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
205   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
206   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
207   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
208   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
209   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
210   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
211   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
212   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
213   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
214   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
215   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
216   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
217   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
218   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
219   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
220   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
221   addFuncDecl(kAwaitTokenAndExecute,
222               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
223   addFuncDecl(kAwaitValueAndExecute,
224               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
225   addFuncDecl(kAwaitAllAndExecute,
226               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // Add malloc/free declarations to the module.
231 //===----------------------------------------------------------------------===//
232 
// C library allocation functions used by the lowering: `malloc` allocates the
// coroutine frame (async.coro.begin) and `free` releases it (async.coro.free).
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
235 
236 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
237                             StringRef name, Type ret, ArrayRef<Type> params) {
238   if (module.lookupSymbol(name))
239     return;
240   Type type = LLVM::LLVMFunctionType::get(ret, params);
241   builder.create<LLVM::LLVMFuncOp>(name, type);
242 }
243 
244 /// Adds malloc/free declarations to the module.
245 static void addCRuntimeDeclarations(ModuleOp module) {
246   using namespace mlir::LLVM;
247 
248   MLIRContext *ctx = module.getContext();
249   auto builder =
250       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
251 
252   auto voidTy = LLVMVoidType::get(ctx);
253   auto i64 = IntegerType::get(ctx, 64);
254   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
255 
256   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
257   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // Coroutine resume function wrapper.
262 //===----------------------------------------------------------------------===//
263 
264 static constexpr const char *kResume = "__resume";
265 
266 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
267 /// intrinsics. We need this function to be able to pass it to the async
268 /// runtime execute API.
269 static void addResumeFunction(ModuleOp module) {
270   if (module.lookupSymbol(kResume))
271     return;
272 
273   MLIRContext *ctx = module.getContext();
274   auto loc = module.getLoc();
275   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
276 
277   auto voidTy = LLVM::LLVMVoidType::get(ctx);
278   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
279 
280   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
281       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
282   resumeOp.setPrivate();
283 
284   auto *block = resumeOp.addEntryBlock();
285   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
286 
287   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
288   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // Convert Async dialect types to LLVM types.
293 //===----------------------------------------------------------------------===//
294 
295 namespace {
296 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
297 /// their runtime type (opaque pointers) and does not convert any other types.
298 class AsyncRuntimeTypeConverter : public TypeConverter {
299 public:
300   AsyncRuntimeTypeConverter() {
301     addConversion([](Type type) { return type; });
302     addConversion(convertAsyncTypes);
303   }
304 
305   static Optional<Type> convertAsyncTypes(Type type) {
306     if (type.isa<TokenType, GroupType, ValueType>())
307       return AsyncAPI::opaquePointerType(type.getContext());
308 
309     if (type.isa<CoroIdType, CoroStateType>())
310       return AsyncAPI::tokenType(type.getContext());
311     if (type.isa<CoroHandleType>())
312       return AsyncAPI::opaquePointerType(type.getContext());
313 
314     return llvm::None;
315   }
316 };
317 } // namespace
318 
319 //===----------------------------------------------------------------------===//
320 // Convert async.coro.id to @llvm.coro.id intrinsic.
321 //===----------------------------------------------------------------------===//
322 
323 namespace {
324 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
325 public:
326   using OpConversionPattern::OpConversionPattern;
327 
328   LogicalResult
329   matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
330                   ConversionPatternRewriter &rewriter) const override {
331     auto token = AsyncAPI::tokenType(op->getContext());
332     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
333     auto loc = op->getLoc();
334 
335     // Constants for initializing coroutine frame.
336     auto constZero = rewriter.create<LLVM::ConstantOp>(
337         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
338     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
339 
340     // Get coroutine id: @llvm.coro.id.
341     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
342         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
343 
344     return success();
345   }
346 };
347 } // namespace
348 
349 //===----------------------------------------------------------------------===//
350 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
351 //===----------------------------------------------------------------------===//
352 
353 namespace {
354 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
355 public:
356   using OpConversionPattern::OpConversionPattern;
357 
358   LogicalResult
359   matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
360                   ConversionPatternRewriter &rewriter) const override {
361     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
362     auto loc = op->getLoc();
363 
364     // Get coroutine frame size: @llvm.coro.size.i64.
365     auto coroSize =
366         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
367 
368     // Allocate memory for the coroutine frame.
369     auto coroAlloc = rewriter.create<LLVM::CallOp>(
370         loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc),
371         ValueRange(coroSize.getResult()));
372 
373     // Begin a coroutine: @llvm.coro.begin.
374     auto coroId = CoroBeginOpAdaptor(operands).id();
375     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
376         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
377 
378     return success();
379   }
380 };
381 } // namespace
382 
383 //===----------------------------------------------------------------------===//
384 // Convert async.coro.free to @llvm.coro.free intrinsic.
385 //===----------------------------------------------------------------------===//
386 
387 namespace {
388 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
389 public:
390   using OpConversionPattern::OpConversionPattern;
391 
392   LogicalResult
393   matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
394                   ConversionPatternRewriter &rewriter) const override {
395     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
396     auto loc = op->getLoc();
397 
398     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
399     auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
400 
401     // Free the memory.
402     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
403         op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree),
404         ValueRange(coroMem.getResult()));
405 
406     return success();
407   }
408 };
409 } // namespace
410 
411 //===----------------------------------------------------------------------===//
412 // Convert async.coro.end to @llvm.coro.end intrinsic.
413 //===----------------------------------------------------------------------===//
414 
415 namespace {
416 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
417 public:
418   using OpConversionPattern::OpConversionPattern;
419 
420   LogicalResult
421   matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
422                   ConversionPatternRewriter &rewriter) const override {
423     // We are not in the block that is part of the unwind sequence.
424     auto constFalse = rewriter.create<LLVM::ConstantOp>(
425         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
426 
427     // Mark the end of a coroutine: @llvm.coro.end.
428     auto coroHdl = CoroEndOpAdaptor(operands).handle();
429     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
430                                      ValueRange({coroHdl, constFalse}));
431     rewriter.eraseOp(op);
432 
433     return success();
434   }
435 };
436 } // namespace
437 
438 //===----------------------------------------------------------------------===//
439 // Convert async.coro.save to @llvm.coro.save intrinsic.
440 //===----------------------------------------------------------------------===//
441 
442 namespace {
443 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
444 public:
445   using OpConversionPattern::OpConversionPattern;
446 
447   LogicalResult
448   matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
449                   ConversionPatternRewriter &rewriter) const override {
450     // Save the coroutine state: @llvm.coro.save
451     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
452         op, AsyncAPI::tokenType(op->getContext()), operands);
453 
454     return success();
455   }
456 };
457 } // namespace
458 
459 //===----------------------------------------------------------------------===//
460 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
461 //===----------------------------------------------------------------------===//
462 
463 namespace {
464 
465 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
466 /// branch to the appropriate block based on the return code.
467 ///
468 /// Before:
469 ///
470 ///   ^suspended:
471 ///     "opBefore"(...)
472 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
473 ///   ^resume:
474 ///     "op"(...)
475 ///   ^cleanup: ...
476 ///   ^suspend: ...
477 ///
478 /// After:
479 ///
480 ///   ^suspended:
481 ///     "opBefore"(...)
482 ///     %suspend = llmv.intr.coro.suspend ...
483 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
484 ///   ^resume:
485 ///     "op"(...)
486 ///   ^cleanup: ...
487 ///   ^suspend: ...
488 ///
489 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
490 public:
491   using OpConversionPattern::OpConversionPattern;
492 
493   LogicalResult
494   matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
495                   ConversionPatternRewriter &rewriter) const override {
496     auto i8 = rewriter.getIntegerType(8);
497     auto i32 = rewriter.getI32Type();
498     auto loc = op->getLoc();
499 
500     // This is not a final suspension point.
501     auto constFalse = rewriter.create<LLVM::ConstantOp>(
502         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
503 
504     // Suspend a coroutine: @llvm.coro.suspend
505     auto coroState = CoroSuspendOpAdaptor(operands).state();
506     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
507         loc, i8, ValueRange({coroState, constFalse}));
508 
509     // Cast return code to i32.
510 
511     // After a suspension point decide if we should branch into resume, cleanup
512     // or suspend block of the coroutine (see @llvm.coro.suspend return code
513     // documentation).
514     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
515     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
516                                               op.cleanupDest()};
517     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
518         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
519         /*defaultDestination=*/op.suspendDest(),
520         /*defaultOperands=*/ValueRange(),
521         /*caseValues=*/caseValues,
522         /*caseDestinations=*/caseDest,
523         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
524         /*branchWeights=*/ArrayRef<int32_t>());
525 
526     return success();
527   }
528 };
529 } // namespace
530 
531 //===----------------------------------------------------------------------===//
532 // Convert async.runtime.create to the corresponding runtime API call.
533 //
534 // To allocate storage for the async values we use getelementptr trick:
535 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
536 //===----------------------------------------------------------------------===//
537 
538 namespace {
539 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
540 public:
541   using OpConversionPattern::OpConversionPattern;
542 
543   LogicalResult
544   matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
545                   ConversionPatternRewriter &rewriter) const override {
546     TypeConverter *converter = getTypeConverter();
547     Type resultType = op->getResultTypes()[0];
548 
549     // Tokens creation maps to a simple function call.
550     if (resultType.isa<TokenType>()) {
551       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
552                                           converter->convertType(resultType));
553       return success();
554     }
555 
556     // To create a value we need to compute the storage requirement.
557     if (auto value = resultType.dyn_cast<ValueType>()) {
558       // Returns the size requirements for the async value storage.
559       auto sizeOf = [&](ValueType valueType) -> Value {
560         auto loc = op->getLoc();
561         auto i32 = rewriter.getI32Type();
562 
563         auto storedType = converter->convertType(valueType.getValueType());
564         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
565 
566         // %Size = getelementptr %T* null, int 1
567         // %SizeI = ptrtoint %T* %Size to i32
568         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
569         auto one = rewriter.create<LLVM::ConstantOp>(
570             loc, i32, rewriter.getI32IntegerAttr(1));
571         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
572                                                 one.getResult());
573         return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep);
574       };
575 
576       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
577                                           sizeOf(value));
578 
579       return success();
580     }
581 
582     return rewriter.notifyMatchFailure(op, "unsupported async type");
583   }
584 };
585 } // namespace
586 
587 //===----------------------------------------------------------------------===//
588 // Convert async.runtime.create_group to the corresponding runtime API call.
589 //===----------------------------------------------------------------------===//
590 
591 namespace {
592 class RuntimeCreateGroupOpLowering
593     : public OpConversionPattern<RuntimeCreateGroupOp> {
594 public:
595   using OpConversionPattern::OpConversionPattern;
596 
597   LogicalResult
598   matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
599                   ConversionPatternRewriter &rewriter) const override {
600     TypeConverter *converter = getTypeConverter();
601     Type resultType = op.getResult().getType();
602 
603     rewriter.replaceOpWithNewOp<CallOp>(
604         op, kCreateGroup, converter->convertType(resultType), operands);
605     return success();
606   }
607 };
608 } // namespace
609 
610 //===----------------------------------------------------------------------===//
611 // Convert async.runtime.set_available to the corresponding runtime API call.
612 //===----------------------------------------------------------------------===//
613 
614 namespace {
615 class RuntimeSetAvailableOpLowering
616     : public OpConversionPattern<RuntimeSetAvailableOp> {
617 public:
618   using OpConversionPattern::OpConversionPattern;
619 
620   LogicalResult
621   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
622                   ConversionPatternRewriter &rewriter) const override {
623     StringRef apiFuncName =
624         TypeSwitch<Type, StringRef>(op.operand().getType())
625             .Case<TokenType>([](Type) { return kEmplaceToken; })
626             .Case<ValueType>([](Type) { return kEmplaceValue; });
627 
628     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
629 
630     return success();
631   }
632 };
633 } // namespace
634 
635 //===----------------------------------------------------------------------===//
636 // Convert async.runtime.set_error to the corresponding runtime API call.
637 //===----------------------------------------------------------------------===//
638 
639 namespace {
640 class RuntimeSetErrorOpLowering
641     : public OpConversionPattern<RuntimeSetErrorOp> {
642 public:
643   using OpConversionPattern::OpConversionPattern;
644 
645   LogicalResult
646   matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
647                   ConversionPatternRewriter &rewriter) const override {
648     StringRef apiFuncName =
649         TypeSwitch<Type, StringRef>(op.operand().getType())
650             .Case<TokenType>([](Type) { return kSetTokenError; })
651             .Case<ValueType>([](Type) { return kSetValueError; });
652 
653     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
654 
655     return success();
656   }
657 };
658 } // namespace
659 
660 //===----------------------------------------------------------------------===//
661 // Convert async.runtime.is_error to the corresponding runtime API call.
662 //===----------------------------------------------------------------------===//
663 
664 namespace {
665 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
666 public:
667   using OpConversionPattern::OpConversionPattern;
668 
669   LogicalResult
670   matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
671                   ConversionPatternRewriter &rewriter) const override {
672     StringRef apiFuncName =
673         TypeSwitch<Type, StringRef>(op.operand().getType())
674             .Case<TokenType>([](Type) { return kIsTokenError; })
675             .Case<GroupType>([](Type) { return kIsGroupError; })
676             .Case<ValueType>([](Type) { return kIsValueError; });
677 
678     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
679                                         operands);
680     return success();
681   }
682 };
683 } // namespace
684 
685 //===----------------------------------------------------------------------===//
686 // Convert async.runtime.await to the corresponding runtime API call.
687 //===----------------------------------------------------------------------===//
688 
689 namespace {
690 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
691 public:
692   using OpConversionPattern::OpConversionPattern;
693 
694   LogicalResult
695   matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
696                   ConversionPatternRewriter &rewriter) const override {
697     StringRef apiFuncName =
698         TypeSwitch<Type, StringRef>(op.operand().getType())
699             .Case<TokenType>([](Type) { return kAwaitToken; })
700             .Case<ValueType>([](Type) { return kAwaitValue; })
701             .Case<GroupType>([](Type) { return kAwaitGroup; });
702 
703     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
704     rewriter.eraseOp(op);
705 
706     return success();
707   }
708 };
709 } // namespace
710 
711 //===----------------------------------------------------------------------===//
712 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
713 //===----------------------------------------------------------------------===//
714 
715 namespace {
716 class RuntimeAwaitAndResumeOpLowering
717     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
718 public:
719   using OpConversionPattern::OpConversionPattern;
720 
721   LogicalResult
722   matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
723                   ConversionPatternRewriter &rewriter) const override {
724     StringRef apiFuncName =
725         TypeSwitch<Type, StringRef>(op.operand().getType())
726             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
727             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
728             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
729 
730     Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
731     Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
732 
733     // A pointer to coroutine resume intrinsic wrapper.
734     addResumeFunction(op->getParentOfType<ModuleOp>());
735     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
736     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
737         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
738 
739     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
740                             ValueRange({operand, handle, resumePtr.res()}));
741     rewriter.eraseOp(op);
742 
743     return success();
744   }
745 };
746 } // namespace
747 
748 //===----------------------------------------------------------------------===//
749 // Convert async.runtime.resume to the corresponding runtime API call.
750 //===----------------------------------------------------------------------===//
751 
752 namespace {
753 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
754 public:
755   using OpConversionPattern::OpConversionPattern;
756 
757   LogicalResult
758   matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
759                   ConversionPatternRewriter &rewriter) const override {
760     // A pointer to coroutine resume intrinsic wrapper.
761     addResumeFunction(op->getParentOfType<ModuleOp>());
762     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
763     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
764         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
765 
766     // Call async runtime API to execute a coroutine in the managed thread.
767     auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
768     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
769                                         ValueRange({coroHdl, resumePtr.res()}));
770 
771     return success();
772   }
773 };
774 } // namespace
775 
776 //===----------------------------------------------------------------------===//
777 // Convert async.runtime.store to the corresponding runtime API call.
778 //===----------------------------------------------------------------------===//
779 
780 namespace {
781 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
782 public:
783   using OpConversionPattern::OpConversionPattern;
784 
785   LogicalResult
786   matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
787                   ConversionPatternRewriter &rewriter) const override {
788     Location loc = op->getLoc();
789 
790     // Get a pointer to the async value storage from the runtime.
791     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
792     auto storage = RuntimeStoreOpAdaptor(operands).storage();
793     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
794                                               TypeRange(i8Ptr), storage);
795 
796     // Cast from i8* to the LLVM pointer type.
797     auto valueType = op.value().getType();
798     auto llvmValueType = getTypeConverter()->convertType(valueType);
799     if (!llvmValueType)
800       return rewriter.notifyMatchFailure(
801           op, "failed to convert stored value type to LLVM type");
802 
803     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
804         loc, LLVM::LLVMPointerType::get(llvmValueType),
805         storagePtr.getResult(0));
806 
807     // Store the yielded value into the async value storage.
808     auto value = RuntimeStoreOpAdaptor(operands).value();
809     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
810 
811     // Erase the original runtime store operation.
812     rewriter.eraseOp(op);
813 
814     return success();
815   }
816 };
817 } // namespace
818 
819 //===----------------------------------------------------------------------===//
820 // Convert async.runtime.load to the corresponding runtime API call.
821 //===----------------------------------------------------------------------===//
822 
823 namespace {
824 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
825 public:
826   using OpConversionPattern::OpConversionPattern;
827 
828   LogicalResult
829   matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
830                   ConversionPatternRewriter &rewriter) const override {
831     Location loc = op->getLoc();
832 
833     // Get a pointer to the async value storage from the runtime.
834     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
835     auto storage = RuntimeLoadOpAdaptor(operands).storage();
836     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
837                                               TypeRange(i8Ptr), storage);
838 
839     // Cast from i8* to the LLVM pointer type.
840     auto valueType = op.result().getType();
841     auto llvmValueType = getTypeConverter()->convertType(valueType);
842     if (!llvmValueType)
843       return rewriter.notifyMatchFailure(
844           op, "failed to convert loaded value type to LLVM type");
845 
846     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
847         loc, LLVM::LLVMPointerType::get(llvmValueType),
848         storagePtr.getResult(0));
849 
850     // Load from the casted pointer.
851     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
852 
853     return success();
854   }
855 };
856 } // namespace
857 
858 //===----------------------------------------------------------------------===//
859 // Convert async.runtime.add_to_group to the corresponding runtime API call.
860 //===----------------------------------------------------------------------===//
861 
862 namespace {
863 class RuntimeAddToGroupOpLowering
864     : public OpConversionPattern<RuntimeAddToGroupOp> {
865 public:
866   using OpConversionPattern::OpConversionPattern;
867 
868   LogicalResult
869   matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
870                   ConversionPatternRewriter &rewriter) const override {
871     // Currently we can only add tokens to the group.
872     if (!op.operand().getType().isa<TokenType>())
873       return rewriter.notifyMatchFailure(op, "only token type is supported");
874 
875     // Replace with a runtime API function call.
876     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
877                                         rewriter.getI64Type(), operands);
878 
879     return success();
880   }
881 };
882 } // namespace
883 
884 //===----------------------------------------------------------------------===//
885 // Async reference counting ops lowering (`async.runtime.add_ref` and
886 // `async.runtime.drop_ref` to the corresponding API calls).
887 //===----------------------------------------------------------------------===//
888 
889 namespace {
890 template <typename RefCountingOp>
891 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
892 public:
893   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
894                                  StringRef apiFunctionName)
895       : OpConversionPattern<RefCountingOp>(converter, ctx),
896         apiFunctionName(apiFunctionName) {}
897 
898   LogicalResult
899   matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
900                   ConversionPatternRewriter &rewriter) const override {
901     auto count =
902         rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
903                                     rewriter.getI32IntegerAttr(op.count()));
904 
905     auto operand = typename RefCountingOp::Adaptor(operands).operand();
906     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
907                                         ValueRange({operand, count}));
908 
909     return success();
910   }
911 
912 private:
913   StringRef apiFunctionName;
914 };
915 
916 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
917 public:
918   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
919       : RefCountingOpLowering(converter, ctx, kAddRef) {}
920 };
921 
922 class RuntimeDropRefOpLowering
923     : public RefCountingOpLowering<RuntimeDropRefOp> {
924 public:
925   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
926       : RefCountingOpLowering(converter, ctx, kDropRef) {}
927 };
928 } // namespace
929 
930 //===----------------------------------------------------------------------===//
931 // Convert return operations that return async values from async regions.
932 //===----------------------------------------------------------------------===//
933 
934 namespace {
935 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
936 public:
937   using OpConversionPattern::OpConversionPattern;
938 
939   LogicalResult
940   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
941                   ConversionPatternRewriter &rewriter) const override {
942     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
943     return success();
944   }
945 };
946 } // namespace
947 
948 //===----------------------------------------------------------------------===//
949 
950 namespace {
951 struct ConvertAsyncToLLVMPass
952     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
953   void runOnOperation() override;
954 };
955 } // namespace
956 
957 void ConvertAsyncToLLVMPass::runOnOperation() {
958   ModuleOp module = getOperation();
959   MLIRContext *ctx = module->getContext();
960 
961   // Add declarations for most functions required by the coroutines lowering.
962   // We delay adding the resume function until it's needed because it currently
963   // fails to compile unless '-O0' is specified.
964   addAsyncRuntimeApiDeclarations(module);
965   addCRuntimeDeclarations(module);
966 
967   // Lower async.runtime and async.coro operations to Async Runtime API and
968   // LLVM coroutine intrinsics.
969 
970   // Convert async dialect types and operations to LLVM dialect.
971   AsyncRuntimeTypeConverter converter;
972   RewritePatternSet patterns(ctx);
973 
974   // We use conversion to LLVM type to lower async.runtime load and store
975   // operations.
976   LLVMTypeConverter llvmConverter(ctx);
977   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
978 
979   // Convert async types in function signatures and function calls.
980   populateFuncOpTypeConversionPattern(patterns, converter);
981   populateCallOpTypeConversionPattern(patterns, converter);
982 
983   // Convert return operations inside async.execute regions.
984   patterns.add<ReturnOpOpConversion>(converter, ctx);
985 
986   // Lower async.runtime operations to the async runtime API calls.
987   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
988                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
989                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
990                RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
991                RuntimeDropRefOpLowering>(converter, ctx);
992 
993   // Lower async.runtime operations that rely on LLVM type converter to convert
994   // from async value payload type to the LLVM type.
995   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
996                RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
997                                                               ctx);
998 
999   // Lower async coroutine operations to LLVM coroutine intrinsics.
1000   patterns
1001       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1002            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1003           converter, ctx);
1004 
1005   ConversionTarget target(*ctx);
1006   target.addLegalOp<ConstantOp, UnrealizedConversionCastOp>();
1007   target.addLegalDialect<LLVM::LLVMDialect>();
1008 
1009   // All operations from Async dialect must be lowered to the runtime API and
1010   // LLVM intrinsics calls.
1011   target.addIllegalDialect<AsyncDialect>();
1012 
1013   // Add dynamic legality constraints to apply conversions defined above.
1014   target.addDynamicallyLegalOp<FuncOp>(
1015       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1016   target.addDynamicallyLegalOp<ReturnOp>(
1017       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
1018   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
1019     return converter.isSignatureLegal(op.getCalleeType());
1020   });
1021 
1022   if (failed(applyPartialConversion(module, target, std::move(patterns))))
1023     signalPassFailure();
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // Patterns for structural type conversions for the Async dialect operations.
1028 //===----------------------------------------------------------------------===//
1029 
1030 namespace {
1031 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1032 public:
1033   using OpConversionPattern::OpConversionPattern;
1034   LogicalResult
1035   matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
1036                   ConversionPatternRewriter &rewriter) const override {
1037     ExecuteOp newOp =
1038         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1039     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1040                                 newOp.getRegion().end());
1041 
1042     // Set operands and update block argument and result types.
1043     newOp->setOperands(operands);
1044     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1045       return failure();
1046     for (auto result : newOp.getResults())
1047       result.setType(typeConverter->convertType(result.getType()));
1048 
1049     rewriter.replaceOp(op, newOp.getResults());
1050     return success();
1051   }
1052 };
1053 
1054 // Dummy pattern to trigger the appropriate type conversion / materialization.
1055 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1056 public:
1057   using OpConversionPattern::OpConversionPattern;
1058   LogicalResult
1059   matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
1060                   ConversionPatternRewriter &rewriter) const override {
1061     rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
1062     return success();
1063   }
1064 };
1065 
1066 // Dummy pattern to trigger the appropriate type conversion / materialization.
1067 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1068 public:
1069   using OpConversionPattern::OpConversionPattern;
1070   LogicalResult
1071   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
1072                   ConversionPatternRewriter &rewriter) const override {
1073     rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
1074     return success();
1075   }
1076 };
1077 } // namespace
1078 
1079 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1080   return std::make_unique<ConvertAsyncToLLVMPass>();
1081 }
1082 
1083 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1084     TypeConverter &typeConverter, RewritePatternSet &patterns,
1085     ConversionTarget &target) {
1086   typeConverter.addConversion([&](TokenType type) { return type; });
1087   typeConverter.addConversion([&](ValueType type) {
1088     Type converted = typeConverter.convertType(type.getValueType());
1089     return converted ? ValueType::get(converted) : converted;
1090   });
1091 
1092   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1093       typeConverter, patterns.getContext());
1094 
1095   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1096       [&](Operation *op) { return typeConverter.isLegal(op); });
1097 }
1098