1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 #define DEBUG_TYPE "convert-async-to-llvm"
28 
29 using namespace mlir;
30 using namespace mlir::async;
31 
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35 
// Symbol names of the Async Runtime C API entry points. The lowering declares
// each of these as a private external function in the module and emits calls
// to it by name, so every string must match the corresponding C API symbol
// exactly.
//
// Reference counting, object creation and state transitions.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
// Error queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
// Awaiting from the calling thread.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Scheduling a coroutine handle together with a resume function pointer.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// Await variants that also take a coroutine handle and resume function.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
62 
63 namespace {
64 /// Async Runtime API function types.
65 ///
66 /// Because we can't create API function signature for type parametrized
67 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
68 /// lowering all async data types become opaque pointers at runtime.
69 struct AsyncAPI {
70   // All async types are lowered to opaque i8* LLVM pointers at runtime.
71   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
72     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
73   }
74 
75   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
76     return LLVM::LLVMTokenType::get(ctx);
77   }
78 
79   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
80     auto ref = opaquePointerType(ctx);
81     auto count = IntegerType::get(ctx, 64);
82     return FunctionType::get(ctx, {ref, count}, {});
83   }
84 
85   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
86     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
87   }
88 
89   static FunctionType createValueFunctionType(MLIRContext *ctx) {
90     auto i64 = IntegerType::get(ctx, 64);
91     auto value = opaquePointerType(ctx);
92     return FunctionType::get(ctx, {i64}, {value});
93   }
94 
95   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
96     auto i64 = IntegerType::get(ctx, 64);
97     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
98   }
99 
100   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
101     auto value = opaquePointerType(ctx);
102     auto storage = opaquePointerType(ctx);
103     return FunctionType::get(ctx, {value}, {storage});
104   }
105 
106   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
107     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
108   }
109 
110   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
111     auto value = opaquePointerType(ctx);
112     return FunctionType::get(ctx, {value}, {});
113   }
114 
115   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
116     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
117   }
118 
119   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
120     auto value = opaquePointerType(ctx);
121     return FunctionType::get(ctx, {value}, {});
122   }
123 
124   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
125     auto i1 = IntegerType::get(ctx, 1);
126     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
127   }
128 
129   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
130     auto value = opaquePointerType(ctx);
131     auto i1 = IntegerType::get(ctx, 1);
132     return FunctionType::get(ctx, {value}, {i1});
133   }
134 
135   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
136     auto i1 = IntegerType::get(ctx, 1);
137     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
138   }
139 
140   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
141     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
142   }
143 
144   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
145     auto value = opaquePointerType(ctx);
146     return FunctionType::get(ctx, {value}, {});
147   }
148 
149   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
150     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
151   }
152 
153   static FunctionType executeFunctionType(MLIRContext *ctx) {
154     auto hdl = opaquePointerType(ctx);
155     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
156     return FunctionType::get(ctx, {hdl, resume}, {});
157   }
158 
159   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
160     auto i64 = IntegerType::get(ctx, 64);
161     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
162                              {i64});
163   }
164 
165   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
166     auto hdl = opaquePointerType(ctx);
167     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
168     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
169   }
170 
171   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
172     auto value = opaquePointerType(ctx);
173     auto hdl = opaquePointerType(ctx);
174     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
175     return FunctionType::get(ctx, {value, hdl, resume}, {});
176   }
177 
178   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
179     auto hdl = opaquePointerType(ctx);
180     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
181     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
182   }
183 
184   // Auxiliary coroutine resume intrinsic wrapper.
185   static Type resumeFunctionType(MLIRContext *ctx) {
186     auto voidTy = LLVM::LLVMVoidType::get(ctx);
187     auto i8Ptr = opaquePointerType(ctx);
188     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
189   }
190 };
191 } // namespace
192 
193 /// Adds Async Runtime C API declarations to the module.
194 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
195   auto builder =
196       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
197 
198   auto addFuncDecl = [&](StringRef name, FunctionType type) {
199     if (module.lookupSymbol(name))
200       return;
201     builder.create<FuncOp>(name, type).setPrivate();
202   };
203 
204   MLIRContext *ctx = module.getContext();
205   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
206   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
207   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
208   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
209   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
210   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
211   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
212   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
213   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
214   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
215   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
216   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
217   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
218   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
219   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
220   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
221   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
222   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
223   addFuncDecl(kAwaitTokenAndExecute,
224               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
225   addFuncDecl(kAwaitValueAndExecute,
226               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
227   addFuncDecl(kAwaitAllAndExecute,
228               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // Coroutine resume function wrapper.
233 //===----------------------------------------------------------------------===//
234 
235 static constexpr const char *kResume = "__resume";
236 
237 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
238 /// intrinsics. We need this function to be able to pass it to the async
239 /// runtime execute API.
240 static void addResumeFunction(ModuleOp module) {
241   if (module.lookupSymbol(kResume))
242     return;
243 
244   MLIRContext *ctx = module.getContext();
245   auto loc = module.getLoc();
246   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
247 
248   auto voidTy = LLVM::LLVMVoidType::get(ctx);
249   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
250 
251   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
252       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
253   resumeOp.setPrivate();
254 
255   auto *block = resumeOp.addEntryBlock();
256   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
257 
258   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
259   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // Convert Async dialect types to LLVM types.
264 //===----------------------------------------------------------------------===//
265 
266 namespace {
267 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
268 /// their runtime type (opaque pointers) and does not convert any other types.
269 class AsyncRuntimeTypeConverter : public TypeConverter {
270 public:
271   AsyncRuntimeTypeConverter() {
272     addConversion([](Type type) { return type; });
273     addConversion(convertAsyncTypes);
274   }
275 
276   static Optional<Type> convertAsyncTypes(Type type) {
277     if (type.isa<TokenType, GroupType, ValueType>())
278       return AsyncAPI::opaquePointerType(type.getContext());
279 
280     if (type.isa<CoroIdType, CoroStateType>())
281       return AsyncAPI::tokenType(type.getContext());
282     if (type.isa<CoroHandleType>())
283       return AsyncAPI::opaquePointerType(type.getContext());
284 
285     return llvm::None;
286   }
287 };
288 } // namespace
289 
290 //===----------------------------------------------------------------------===//
291 // Convert async.coro.id to @llvm.coro.id intrinsic.
292 //===----------------------------------------------------------------------===//
293 
294 namespace {
295 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
296 public:
297   using OpConversionPattern::OpConversionPattern;
298 
299   LogicalResult
300   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
301                   ConversionPatternRewriter &rewriter) const override {
302     auto token = AsyncAPI::tokenType(op->getContext());
303     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
304     auto loc = op->getLoc();
305 
306     // Constants for initializing coroutine frame.
307     auto constZero = rewriter.create<LLVM::ConstantOp>(
308         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
309     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
310 
311     // Get coroutine id: @llvm.coro.id.
312     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
313         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
314 
315     return success();
316   }
317 };
318 } // namespace
319 
320 //===----------------------------------------------------------------------===//
321 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
322 //===----------------------------------------------------------------------===//
323 
324 namespace {
325 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
326 public:
327   using OpConversionPattern::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override {
332     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
333     auto loc = op->getLoc();
334 
335     // Get coroutine frame size: @llvm.coro.size.i64.
336     Value coroSize =
337         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
338     // The coroutine lowering doesn't properly account for alignment of the
339     // frame, so align everything to 64 bytes which ought to be enough for
340     // everyone. https://llvm.org/PR53148
341     constexpr int64_t coroAlign = 64;
342     auto makeConstant = [&](uint64_t c) {
343       return rewriter.create<LLVM::ConstantOp>(
344           op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c));
345     };
346     // Round up the size to the alignment. This is a requirement of
347     // aligned_alloc.
348     coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize,
349                                             makeConstant(coroAlign - 1));
350     coroSize = rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize,
351                                             makeConstant(-coroAlign));
352 
353     // Allocate memory for the coroutine frame.
354     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
355         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
356     auto coroAlloc = rewriter.create<LLVM::CallOp>(
357         loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
358         ValueRange{makeConstant(coroAlign), coroSize});
359 
360     // Begin a coroutine: @llvm.coro.begin.
361     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
362     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
363         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
364 
365     return success();
366   }
367 };
368 } // namespace
369 
370 //===----------------------------------------------------------------------===//
371 // Convert async.coro.free to @llvm.coro.free intrinsic.
372 //===----------------------------------------------------------------------===//
373 
374 namespace {
375 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
376 public:
377   using OpConversionPattern::OpConversionPattern;
378 
379   LogicalResult
380   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
381                   ConversionPatternRewriter &rewriter) const override {
382     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
383     auto loc = op->getLoc();
384 
385     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
386     auto coroMem =
387         rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
388 
389     // Free the memory.
390     auto freeFuncOp =
391         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
392     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
393                                               SymbolRefAttr::get(freeFuncOp),
394                                               ValueRange(coroMem.getResult()));
395 
396     return success();
397   }
398 };
399 } // namespace
400 
401 //===----------------------------------------------------------------------===//
402 // Convert async.coro.end to @llvm.coro.end intrinsic.
403 //===----------------------------------------------------------------------===//
404 
405 namespace {
406 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
407 public:
408   using OpConversionPattern::OpConversionPattern;
409 
410   LogicalResult
411   matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
412                   ConversionPatternRewriter &rewriter) const override {
413     // We are not in the block that is part of the unwind sequence.
414     auto constFalse = rewriter.create<LLVM::ConstantOp>(
415         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
416 
417     // Mark the end of a coroutine: @llvm.coro.end.
418     auto coroHdl = adaptor.handle();
419     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
420                                      ValueRange({coroHdl, constFalse}));
421     rewriter.eraseOp(op);
422 
423     return success();
424   }
425 };
426 } // namespace
427 
428 //===----------------------------------------------------------------------===//
429 // Convert async.coro.save to @llvm.coro.save intrinsic.
430 //===----------------------------------------------------------------------===//
431 
432 namespace {
433 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
434 public:
435   using OpConversionPattern::OpConversionPattern;
436 
437   LogicalResult
438   matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
439                   ConversionPatternRewriter &rewriter) const override {
440     // Save the coroutine state: @llvm.coro.save
441     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
442         op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
443 
444     return success();
445   }
446 };
447 } // namespace
448 
449 //===----------------------------------------------------------------------===//
450 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
451 //===----------------------------------------------------------------------===//
452 
453 namespace {
454 
455 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
456 /// branch to the appropriate block based on the return code.
457 ///
458 /// Before:
459 ///
460 ///   ^suspended:
461 ///     "opBefore"(...)
462 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
463 ///   ^resume:
464 ///     "op"(...)
465 ///   ^cleanup: ...
466 ///   ^suspend: ...
467 ///
468 /// After:
469 ///
470 ///   ^suspended:
471 ///     "opBefore"(...)
472 ///     %suspend = llmv.intr.coro.suspend ...
473 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
474 ///   ^resume:
475 ///     "op"(...)
476 ///   ^cleanup: ...
477 ///   ^suspend: ...
478 ///
479 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
480 public:
481   using OpConversionPattern::OpConversionPattern;
482 
483   LogicalResult
484   matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
485                   ConversionPatternRewriter &rewriter) const override {
486     auto i8 = rewriter.getIntegerType(8);
487     auto i32 = rewriter.getI32Type();
488     auto loc = op->getLoc();
489 
490     // This is not a final suspension point.
491     auto constFalse = rewriter.create<LLVM::ConstantOp>(
492         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
493 
494     // Suspend a coroutine: @llvm.coro.suspend
495     auto coroState = adaptor.state();
496     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
497         loc, i8, ValueRange({coroState, constFalse}));
498 
499     // Cast return code to i32.
500 
501     // After a suspension point decide if we should branch into resume, cleanup
502     // or suspend block of the coroutine (see @llvm.coro.suspend return code
503     // documentation).
504     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
505     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
506                                               op.cleanupDest()};
507     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
508         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
509         /*defaultDestination=*/op.suspendDest(),
510         /*defaultOperands=*/ValueRange(),
511         /*caseValues=*/caseValues,
512         /*caseDestinations=*/caseDest,
513         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
514         /*branchWeights=*/ArrayRef<int32_t>());
515 
516     return success();
517   }
518 };
519 } // namespace
520 
521 //===----------------------------------------------------------------------===//
522 // Convert async.runtime.create to the corresponding runtime API call.
523 //
524 // To allocate storage for the async values we use getelementptr trick:
525 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
526 //===----------------------------------------------------------------------===//
527 
528 namespace {
529 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
530 public:
531   using OpConversionPattern::OpConversionPattern;
532 
533   LogicalResult
534   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
535                   ConversionPatternRewriter &rewriter) const override {
536     TypeConverter *converter = getTypeConverter();
537     Type resultType = op->getResultTypes()[0];
538 
539     // Tokens creation maps to a simple function call.
540     if (resultType.isa<TokenType>()) {
541       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
542                                           converter->convertType(resultType));
543       return success();
544     }
545 
546     // To create a value we need to compute the storage requirement.
547     if (auto value = resultType.dyn_cast<ValueType>()) {
548       // Returns the size requirements for the async value storage.
549       auto sizeOf = [&](ValueType valueType) -> Value {
550         auto loc = op->getLoc();
551         auto i64 = rewriter.getI64Type();
552 
553         auto storedType = converter->convertType(valueType.getValueType());
554         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
555 
556         // %Size = getelementptr %T* null, int 1
557         // %SizeI = ptrtoint %T* %Size to i64
558         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
559         auto one = rewriter.create<LLVM::ConstantOp>(
560             loc, i64, rewriter.getI64IntegerAttr(1));
561         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
562                                                 one.getResult());
563         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
564       };
565 
566       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
567                                           sizeOf(value));
568 
569       return success();
570     }
571 
572     return rewriter.notifyMatchFailure(op, "unsupported async type");
573   }
574 };
575 } // namespace
576 
577 //===----------------------------------------------------------------------===//
578 // Convert async.runtime.create_group to the corresponding runtime API call.
579 //===----------------------------------------------------------------------===//
580 
581 namespace {
582 class RuntimeCreateGroupOpLowering
583     : public OpConversionPattern<RuntimeCreateGroupOp> {
584 public:
585   using OpConversionPattern::OpConversionPattern;
586 
587   LogicalResult
588   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
589                   ConversionPatternRewriter &rewriter) const override {
590     TypeConverter *converter = getTypeConverter();
591     Type resultType = op.getResult().getType();
592 
593     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup,
594                                         converter->convertType(resultType),
595                                         adaptor.getOperands());
596     return success();
597   }
598 };
599 } // namespace
600 
601 //===----------------------------------------------------------------------===//
602 // Convert async.runtime.set_available to the corresponding runtime API call.
603 //===----------------------------------------------------------------------===//
604 
605 namespace {
606 class RuntimeSetAvailableOpLowering
607     : public OpConversionPattern<RuntimeSetAvailableOp> {
608 public:
609   using OpConversionPattern::OpConversionPattern;
610 
611   LogicalResult
612   matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
613                   ConversionPatternRewriter &rewriter) const override {
614     StringRef apiFuncName =
615         TypeSwitch<Type, StringRef>(op.operand().getType())
616             .Case<TokenType>([](Type) { return kEmplaceToken; })
617             .Case<ValueType>([](Type) { return kEmplaceValue; });
618 
619     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
620                                         adaptor.getOperands());
621 
622     return success();
623   }
624 };
625 } // namespace
626 
627 //===----------------------------------------------------------------------===//
628 // Convert async.runtime.set_error to the corresponding runtime API call.
629 //===----------------------------------------------------------------------===//
630 
631 namespace {
632 class RuntimeSetErrorOpLowering
633     : public OpConversionPattern<RuntimeSetErrorOp> {
634 public:
635   using OpConversionPattern::OpConversionPattern;
636 
637   LogicalResult
638   matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
639                   ConversionPatternRewriter &rewriter) const override {
640     StringRef apiFuncName =
641         TypeSwitch<Type, StringRef>(op.operand().getType())
642             .Case<TokenType>([](Type) { return kSetTokenError; })
643             .Case<ValueType>([](Type) { return kSetValueError; });
644 
645     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
646                                         adaptor.getOperands());
647 
648     return success();
649   }
650 };
651 } // namespace
652 
653 //===----------------------------------------------------------------------===//
654 // Convert async.runtime.is_error to the corresponding runtime API call.
655 //===----------------------------------------------------------------------===//
656 
657 namespace {
658 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
659 public:
660   using OpConversionPattern::OpConversionPattern;
661 
662   LogicalResult
663   matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
664                   ConversionPatternRewriter &rewriter) const override {
665     StringRef apiFuncName =
666         TypeSwitch<Type, StringRef>(op.operand().getType())
667             .Case<TokenType>([](Type) { return kIsTokenError; })
668             .Case<GroupType>([](Type) { return kIsGroupError; })
669             .Case<ValueType>([](Type) { return kIsValueError; });
670 
671     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
672                                         adaptor.getOperands());
673     return success();
674   }
675 };
676 } // namespace
677 
678 //===----------------------------------------------------------------------===//
679 // Convert async.runtime.await to the corresponding runtime API call.
680 //===----------------------------------------------------------------------===//
681 
682 namespace {
683 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
684 public:
685   using OpConversionPattern::OpConversionPattern;
686 
687   LogicalResult
688   matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
689                   ConversionPatternRewriter &rewriter) const override {
690     StringRef apiFuncName =
691         TypeSwitch<Type, StringRef>(op.operand().getType())
692             .Case<TokenType>([](Type) { return kAwaitToken; })
693             .Case<ValueType>([](Type) { return kAwaitValue; })
694             .Case<GroupType>([](Type) { return kAwaitGroup; });
695 
696     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
697                             adaptor.getOperands());
698     rewriter.eraseOp(op);
699 
700     return success();
701   }
702 };
703 } // namespace
704 
705 //===----------------------------------------------------------------------===//
706 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
707 //===----------------------------------------------------------------------===//
708 
709 namespace {
710 class RuntimeAwaitAndResumeOpLowering
711     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
712 public:
713   using OpConversionPattern::OpConversionPattern;
714 
715   LogicalResult
716   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
717                   ConversionPatternRewriter &rewriter) const override {
718     StringRef apiFuncName =
719         TypeSwitch<Type, StringRef>(op.operand().getType())
720             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
721             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
722             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
723 
724     Value operand = adaptor.operand();
725     Value handle = adaptor.handle();
726 
727     // A pointer to coroutine resume intrinsic wrapper.
728     addResumeFunction(op->getParentOfType<ModuleOp>());
729     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
730     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
731         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
732 
733     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
734                             ValueRange({operand, handle, resumePtr.getRes()}));
735     rewriter.eraseOp(op);
736 
737     return success();
738   }
739 };
740 } // namespace
741 
742 //===----------------------------------------------------------------------===//
743 // Convert async.runtime.resume to the corresponding runtime API call.
744 //===----------------------------------------------------------------------===//
745 
746 namespace {
747 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
748 public:
749   using OpConversionPattern::OpConversionPattern;
750 
751   LogicalResult
752   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
753                   ConversionPatternRewriter &rewriter) const override {
754     // A pointer to coroutine resume intrinsic wrapper.
755     addResumeFunction(op->getParentOfType<ModuleOp>());
756     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
757     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
758         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
759 
760     // Call async runtime API to execute a coroutine in the managed thread.
761     auto coroHdl = adaptor.handle();
762     rewriter.replaceOpWithNewOp<CallOp>(
763         op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
764 
765     return success();
766   }
767 };
768 } // namespace
769 
770 //===----------------------------------------------------------------------===//
771 // Convert async.runtime.store to the corresponding runtime API call.
772 //===----------------------------------------------------------------------===//
773 
774 namespace {
775 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
776 public:
777   using OpConversionPattern::OpConversionPattern;
778 
779   LogicalResult
780   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
781                   ConversionPatternRewriter &rewriter) const override {
782     Location loc = op->getLoc();
783 
784     // Get a pointer to the async value storage from the runtime.
785     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
786     auto storage = adaptor.storage();
787     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
788                                               TypeRange(i8Ptr), storage);
789 
790     // Cast from i8* to the LLVM pointer type.
791     auto valueType = op.value().getType();
792     auto llvmValueType = getTypeConverter()->convertType(valueType);
793     if (!llvmValueType)
794       return rewriter.notifyMatchFailure(
795           op, "failed to convert stored value type to LLVM type");
796 
797     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
798         loc, LLVM::LLVMPointerType::get(llvmValueType),
799         storagePtr.getResult(0));
800 
801     // Store the yielded value into the async value storage.
802     auto value = adaptor.value();
803     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
804 
805     // Erase the original runtime store operation.
806     rewriter.eraseOp(op);
807 
808     return success();
809   }
810 };
811 } // namespace
812 
813 //===----------------------------------------------------------------------===//
814 // Convert async.runtime.load to the corresponding runtime API call.
815 //===----------------------------------------------------------------------===//
816 
817 namespace {
818 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
819 public:
820   using OpConversionPattern::OpConversionPattern;
821 
822   LogicalResult
823   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
824                   ConversionPatternRewriter &rewriter) const override {
825     Location loc = op->getLoc();
826 
827     // Get a pointer to the async value storage from the runtime.
828     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
829     auto storage = adaptor.storage();
830     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
831                                               TypeRange(i8Ptr), storage);
832 
833     // Cast from i8* to the LLVM pointer type.
834     auto valueType = op.result().getType();
835     auto llvmValueType = getTypeConverter()->convertType(valueType);
836     if (!llvmValueType)
837       return rewriter.notifyMatchFailure(
838           op, "failed to convert loaded value type to LLVM type");
839 
840     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
841         loc, LLVM::LLVMPointerType::get(llvmValueType),
842         storagePtr.getResult(0));
843 
844     // Load from the casted pointer.
845     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
846 
847     return success();
848   }
849 };
850 } // namespace
851 
852 //===----------------------------------------------------------------------===//
853 // Convert async.runtime.add_to_group to the corresponding runtime API call.
854 //===----------------------------------------------------------------------===//
855 
856 namespace {
857 class RuntimeAddToGroupOpLowering
858     : public OpConversionPattern<RuntimeAddToGroupOp> {
859 public:
860   using OpConversionPattern::OpConversionPattern;
861 
862   LogicalResult
863   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
864                   ConversionPatternRewriter &rewriter) const override {
865     // Currently we can only add tokens to the group.
866     if (!op.operand().getType().isa<TokenType>())
867       return rewriter.notifyMatchFailure(op, "only token type is supported");
868 
869     // Replace with a runtime API function call.
870     rewriter.replaceOpWithNewOp<CallOp>(
871         op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
872 
873     return success();
874   }
875 };
876 } // namespace
877 
878 //===----------------------------------------------------------------------===//
879 // Async reference counting ops lowering (`async.runtime.add_ref` and
880 // `async.runtime.drop_ref` to the corresponding API calls).
881 //===----------------------------------------------------------------------===//
882 
883 namespace {
884 template <typename RefCountingOp>
885 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
886 public:
887   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
888                                  StringRef apiFunctionName)
889       : OpConversionPattern<RefCountingOp>(converter, ctx),
890         apiFunctionName(apiFunctionName) {}
891 
892   LogicalResult
893   matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
894                   ConversionPatternRewriter &rewriter) const override {
895     auto count = rewriter.create<arith::ConstantOp>(
896         op->getLoc(), rewriter.getI64Type(),
897         rewriter.getI64IntegerAttr(op.count()));
898 
899     auto operand = adaptor.operand();
900     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
901                                         ValueRange({operand, count}));
902 
903     return success();
904   }
905 
906 private:
907   StringRef apiFunctionName;
908 };
909 
910 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
911 public:
912   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
913       : RefCountingOpLowering(converter, ctx, kAddRef) {}
914 };
915 
916 class RuntimeDropRefOpLowering
917     : public RefCountingOpLowering<RuntimeDropRefOp> {
918 public:
919   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
920       : RefCountingOpLowering(converter, ctx, kDropRef) {}
921 };
922 } // namespace
923 
924 //===----------------------------------------------------------------------===//
925 // Convert return operations that return async values from async regions.
926 //===----------------------------------------------------------------------===//
927 
928 namespace {
929 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
930 public:
931   using OpConversionPattern::OpConversionPattern;
932 
933   LogicalResult
934   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
935                   ConversionPatternRewriter &rewriter) const override {
936     rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
937     return success();
938   }
939 };
940 } // namespace
941 
942 //===----------------------------------------------------------------------===//
943 
944 namespace {
945 struct ConvertAsyncToLLVMPass
946     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
947   void runOnOperation() override;
948 };
949 } // namespace
950 
951 void ConvertAsyncToLLVMPass::runOnOperation() {
952   ModuleOp module = getOperation();
953   MLIRContext *ctx = module->getContext();
954 
955   // Add declarations for most functions required by the coroutines lowering.
956   // We delay adding the resume function until it's needed because it currently
957   // fails to compile unless '-O0' is specified.
958   addAsyncRuntimeApiDeclarations(module);
959 
960   // Lower async.runtime and async.coro operations to Async Runtime API and
961   // LLVM coroutine intrinsics.
962 
963   // Convert async dialect types and operations to LLVM dialect.
964   AsyncRuntimeTypeConverter converter;
965   RewritePatternSet patterns(ctx);
966 
967   // We use conversion to LLVM type to lower async.runtime load and store
968   // operations.
969   LLVMTypeConverter llvmConverter(ctx);
970   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
971 
972   // Convert async types in function signatures and function calls.
973   populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, converter);
974   populateCallOpTypeConversionPattern(patterns, converter);
975 
976   // Convert return operations inside async.execute regions.
977   patterns.add<ReturnOpOpConversion>(converter, ctx);
978 
979   // Lower async.runtime operations to the async runtime API calls.
980   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
981                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
982                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
983                RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
984                RuntimeDropRefOpLowering>(converter, ctx);
985 
986   // Lower async.runtime operations that rely on LLVM type converter to convert
987   // from async value payload type to the LLVM type.
988   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
989                RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
990                                                               ctx);
991 
992   // Lower async coroutine operations to LLVM coroutine intrinsics.
993   patterns
994       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
995            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
996           converter, ctx);
997 
998   ConversionTarget target(*ctx);
999   target
1000       .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>();
1001   target.addLegalDialect<LLVM::LLVMDialect>();
1002 
1003   // All operations from Async dialect must be lowered to the runtime API and
1004   // LLVM intrinsics calls.
1005   target.addIllegalDialect<AsyncDialect>();
1006 
1007   // Add dynamic legality constraints to apply conversions defined above.
1008   target.addDynamicallyLegalOp<FuncOp>(
1009       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1010   target.addDynamicallyLegalOp<ReturnOp>(
1011       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
1012   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
1013     return converter.isSignatureLegal(op.getCalleeType());
1014   });
1015 
1016   if (failed(applyPartialConversion(module, target, std::move(patterns))))
1017     signalPassFailure();
1018 }
1019 
1020 //===----------------------------------------------------------------------===//
1021 // Patterns for structural type conversions for the Async dialect operations.
1022 //===----------------------------------------------------------------------===//
1023 
1024 namespace {
1025 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1026 public:
1027   using OpConversionPattern::OpConversionPattern;
1028   LogicalResult
1029   matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1030                   ConversionPatternRewriter &rewriter) const override {
1031     ExecuteOp newOp =
1032         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1033     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1034                                 newOp.getRegion().end());
1035 
1036     // Set operands and update block argument and result types.
1037     newOp->setOperands(adaptor.getOperands());
1038     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1039       return failure();
1040     for (auto result : newOp.getResults())
1041       result.setType(typeConverter->convertType(result.getType()));
1042 
1043     rewriter.replaceOp(op, newOp.getResults());
1044     return success();
1045   }
1046 };
1047 
1048 // Dummy pattern to trigger the appropriate type conversion / materialization.
1049 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1050 public:
1051   using OpConversionPattern::OpConversionPattern;
1052   LogicalResult
1053   matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1054                   ConversionPatternRewriter &rewriter) const override {
1055     rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1056     return success();
1057   }
1058 };
1059 
1060 // Dummy pattern to trigger the appropriate type conversion / materialization.
1061 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1062 public:
1063   using OpConversionPattern::OpConversionPattern;
1064   LogicalResult
1065   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1066                   ConversionPatternRewriter &rewriter) const override {
1067     rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1068     return success();
1069   }
1070 };
1071 } // namespace
1072 
1073 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1074   return std::make_unique<ConvertAsyncToLLVMPass>();
1075 }
1076 
1077 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1078     TypeConverter &typeConverter, RewritePatternSet &patterns,
1079     ConversionTarget &target) {
1080   typeConverter.addConversion([&](TokenType type) { return type; });
1081   typeConverter.addConversion([&](ValueType type) {
1082     Type converted = typeConverter.convertType(type.getValueType());
1083     return converted ? ValueType::get(converted) : converted;
1084   });
1085 
1086   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1087       typeConverter, patterns.getContext());
1088 
1089   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1090       [&](Operation *op) { return typeConverter.isLegal(op); });
1091 }
1092