1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 #define DEBUG_TYPE "convert-async-to-llvm"
27 
28 using namespace mlir;
29 using namespace mlir::async;
30 
31 //===----------------------------------------------------------------------===//
32 // Async Runtime C API declaration.
33 //===----------------------------------------------------------------------===//
34 
// Reference counting of runtime objects.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Construction of tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Marking tokens/values as available or as errors.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Error-state queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Awaiting tokens, values and groups.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Value storage access and group membership.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Handing coroutines to the runtime for (possibly deferred) execution.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
61 
62 namespace {
63 /// Async Runtime API function types.
64 ///
65 /// Because we can't create API function signature for type parametrized
66 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
67 /// lowering all async data types become opaque pointers at runtime.
68 struct AsyncAPI {
69   // All async types are lowered to opaque i8* LLVM pointers at runtime.
70   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
71     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
72   }
73 
74   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
75     return LLVM::LLVMTokenType::get(ctx);
76   }
77 
78   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
79     auto ref = opaquePointerType(ctx);
80     auto count = IntegerType::get(ctx, 64);
81     return FunctionType::get(ctx, {ref, count}, {});
82   }
83 
84   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
85     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
86   }
87 
88   static FunctionType createValueFunctionType(MLIRContext *ctx) {
89     auto i64 = IntegerType::get(ctx, 64);
90     auto value = opaquePointerType(ctx);
91     return FunctionType::get(ctx, {i64}, {value});
92   }
93 
94   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
95     auto i64 = IntegerType::get(ctx, 64);
96     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
97   }
98 
99   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
100     auto value = opaquePointerType(ctx);
101     auto storage = opaquePointerType(ctx);
102     return FunctionType::get(ctx, {value}, {storage});
103   }
104 
105   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
106     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
107   }
108 
109   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
110     auto value = opaquePointerType(ctx);
111     return FunctionType::get(ctx, {value}, {});
112   }
113 
114   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
115     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
116   }
117 
118   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
119     auto value = opaquePointerType(ctx);
120     return FunctionType::get(ctx, {value}, {});
121   }
122 
123   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
124     auto i1 = IntegerType::get(ctx, 1);
125     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
126   }
127 
128   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
129     auto value = opaquePointerType(ctx);
130     auto i1 = IntegerType::get(ctx, 1);
131     return FunctionType::get(ctx, {value}, {i1});
132   }
133 
134   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
135     auto i1 = IntegerType::get(ctx, 1);
136     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
137   }
138 
139   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
140     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
141   }
142 
143   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
144     auto value = opaquePointerType(ctx);
145     return FunctionType::get(ctx, {value}, {});
146   }
147 
148   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
149     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
150   }
151 
152   static FunctionType executeFunctionType(MLIRContext *ctx) {
153     auto hdl = opaquePointerType(ctx);
154     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
155     return FunctionType::get(ctx, {hdl, resume}, {});
156   }
157 
158   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
159     auto i64 = IntegerType::get(ctx, 64);
160     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
161                              {i64});
162   }
163 
164   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
165     auto hdl = opaquePointerType(ctx);
166     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
167     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
168   }
169 
170   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
171     auto value = opaquePointerType(ctx);
172     auto hdl = opaquePointerType(ctx);
173     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
174     return FunctionType::get(ctx, {value, hdl, resume}, {});
175   }
176 
177   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
178     auto hdl = opaquePointerType(ctx);
179     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
180     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
181   }
182 
183   // Auxiliary coroutine resume intrinsic wrapper.
184   static Type resumeFunctionType(MLIRContext *ctx) {
185     auto voidTy = LLVM::LLVMVoidType::get(ctx);
186     auto i8Ptr = opaquePointerType(ctx);
187     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
188   }
189 };
190 } // namespace
191 
192 /// Adds Async Runtime C API declarations to the module.
193 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
194   auto builder =
195       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
196 
197   auto addFuncDecl = [&](StringRef name, FunctionType type) {
198     if (module.lookupSymbol(name))
199       return;
200     builder.create<FuncOp>(name, type).setPrivate();
201   };
202 
203   MLIRContext *ctx = module.getContext();
204   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
205   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
206   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
207   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
208   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
209   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
210   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
211   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
212   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
213   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
214   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
215   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
216   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
217   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
218   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
219   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
220   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
221   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
222   addFuncDecl(kAwaitTokenAndExecute,
223               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
224   addFuncDecl(kAwaitValueAndExecute,
225               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
226   addFuncDecl(kAwaitAllAndExecute,
227               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Add malloc/free declarations to the module.
232 //===----------------------------------------------------------------------===//
233 
234 static constexpr const char *kMalloc = "malloc";
235 static constexpr const char *kFree = "free";
236 
237 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
238                             StringRef name, Type ret, ArrayRef<Type> params) {
239   if (module.lookupSymbol(name))
240     return;
241   Type type = LLVM::LLVMFunctionType::get(ret, params);
242   builder.create<LLVM::LLVMFuncOp>(name, type);
243 }
244 
245 /// Adds malloc/free declarations to the module.
246 static void addCRuntimeDeclarations(ModuleOp module) {
247   using namespace mlir::LLVM;
248 
249   MLIRContext *ctx = module.getContext();
250   auto builder =
251       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
252 
253   auto voidTy = LLVMVoidType::get(ctx);
254   auto i64 = IntegerType::get(ctx, 64);
255   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
256 
257   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
258   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // Coroutine resume function wrapper.
263 //===----------------------------------------------------------------------===//
264 
265 static constexpr const char *kResume = "__resume";
266 
267 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
268 /// intrinsics. We need this function to be able to pass it to the async
269 /// runtime execute API.
270 static void addResumeFunction(ModuleOp module) {
271   if (module.lookupSymbol(kResume))
272     return;
273 
274   MLIRContext *ctx = module.getContext();
275   auto loc = module.getLoc();
276   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
277 
278   auto voidTy = LLVM::LLVMVoidType::get(ctx);
279   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
280 
281   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
282       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
283   resumeOp.setPrivate();
284 
285   auto *block = resumeOp.addEntryBlock();
286   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
287 
288   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
289   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // Convert Async dialect types to LLVM types.
294 //===----------------------------------------------------------------------===//
295 
296 namespace {
297 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
298 /// their runtime type (opaque pointers) and does not convert any other types.
299 class AsyncRuntimeTypeConverter : public TypeConverter {
300 public:
301   AsyncRuntimeTypeConverter() {
302     addConversion([](Type type) { return type; });
303     addConversion(convertAsyncTypes);
304   }
305 
306   static Optional<Type> convertAsyncTypes(Type type) {
307     if (type.isa<TokenType, GroupType, ValueType>())
308       return AsyncAPI::opaquePointerType(type.getContext());
309 
310     if (type.isa<CoroIdType, CoroStateType>())
311       return AsyncAPI::tokenType(type.getContext());
312     if (type.isa<CoroHandleType>())
313       return AsyncAPI::opaquePointerType(type.getContext());
314 
315     return llvm::None;
316   }
317 };
318 } // namespace
319 
320 //===----------------------------------------------------------------------===//
321 // Convert async.coro.id to @llvm.coro.id intrinsic.
322 //===----------------------------------------------------------------------===//
323 
324 namespace {
325 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
326 public:
327   using OpConversionPattern::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override {
332     auto token = AsyncAPI::tokenType(op->getContext());
333     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
334     auto loc = op->getLoc();
335 
336     // Constants for initializing coroutine frame.
337     auto constZero = rewriter.create<LLVM::ConstantOp>(
338         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
339     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
340 
341     // Get coroutine id: @llvm.coro.id.
342     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
343         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
344 
345     return success();
346   }
347 };
348 } // namespace
349 
350 //===----------------------------------------------------------------------===//
351 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
352 //===----------------------------------------------------------------------===//
353 
354 namespace {
355 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
356 public:
357   using OpConversionPattern::OpConversionPattern;
358 
359   LogicalResult
360   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
361                   ConversionPatternRewriter &rewriter) const override {
362     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
363     auto loc = op->getLoc();
364 
365     // Get coroutine frame size: @llvm.coro.size.i64.
366     auto coroSize =
367         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
368 
369     // Allocate memory for the coroutine frame.
370     auto coroAlloc = rewriter.create<LLVM::CallOp>(
371         loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc),
372         ValueRange(coroSize.getResult()));
373 
374     // Begin a coroutine: @llvm.coro.begin.
375     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
376     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
377         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
378 
379     return success();
380   }
381 };
382 } // namespace
383 
384 //===----------------------------------------------------------------------===//
385 // Convert async.coro.free to @llvm.coro.free intrinsic.
386 //===----------------------------------------------------------------------===//
387 
388 namespace {
389 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
390 public:
391   using OpConversionPattern::OpConversionPattern;
392 
393   LogicalResult
394   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
395                   ConversionPatternRewriter &rewriter) const override {
396     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
397     auto loc = op->getLoc();
398 
399     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
400     auto coroMem =
401         rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
402 
403     // Free the memory.
404     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
405         op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree),
406         ValueRange(coroMem.getResult()));
407 
408     return success();
409   }
410 };
411 } // namespace
412 
413 //===----------------------------------------------------------------------===//
414 // Convert async.coro.end to @llvm.coro.end intrinsic.
415 //===----------------------------------------------------------------------===//
416 
417 namespace {
418 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
419 public:
420   using OpConversionPattern::OpConversionPattern;
421 
422   LogicalResult
423   matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
424                   ConversionPatternRewriter &rewriter) const override {
425     // We are not in the block that is part of the unwind sequence.
426     auto constFalse = rewriter.create<LLVM::ConstantOp>(
427         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
428 
429     // Mark the end of a coroutine: @llvm.coro.end.
430     auto coroHdl = adaptor.handle();
431     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
432                                      ValueRange({coroHdl, constFalse}));
433     rewriter.eraseOp(op);
434 
435     return success();
436   }
437 };
438 } // namespace
439 
440 //===----------------------------------------------------------------------===//
441 // Convert async.coro.save to @llvm.coro.save intrinsic.
442 //===----------------------------------------------------------------------===//
443 
444 namespace {
445 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
446 public:
447   using OpConversionPattern::OpConversionPattern;
448 
449   LogicalResult
450   matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
451                   ConversionPatternRewriter &rewriter) const override {
452     // Save the coroutine state: @llvm.coro.save
453     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
454         op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
455 
456     return success();
457   }
458 };
459 } // namespace
460 
461 //===----------------------------------------------------------------------===//
462 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
463 //===----------------------------------------------------------------------===//
464 
465 namespace {
466 
467 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
468 /// branch to the appropriate block based on the return code.
469 ///
470 /// Before:
471 ///
472 ///   ^suspended:
473 ///     "opBefore"(...)
474 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
475 ///   ^resume:
476 ///     "op"(...)
477 ///   ^cleanup: ...
478 ///   ^suspend: ...
479 ///
480 /// After:
481 ///
482 ///   ^suspended:
483 ///     "opBefore"(...)
484 ///     %suspend = llmv.intr.coro.suspend ...
485 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
486 ///   ^resume:
487 ///     "op"(...)
488 ///   ^cleanup: ...
489 ///   ^suspend: ...
490 ///
491 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
492 public:
493   using OpConversionPattern::OpConversionPattern;
494 
495   LogicalResult
496   matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
497                   ConversionPatternRewriter &rewriter) const override {
498     auto i8 = rewriter.getIntegerType(8);
499     auto i32 = rewriter.getI32Type();
500     auto loc = op->getLoc();
501 
502     // This is not a final suspension point.
503     auto constFalse = rewriter.create<LLVM::ConstantOp>(
504         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
505 
506     // Suspend a coroutine: @llvm.coro.suspend
507     auto coroState = adaptor.state();
508     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
509         loc, i8, ValueRange({coroState, constFalse}));
510 
511     // Cast return code to i32.
512 
513     // After a suspension point decide if we should branch into resume, cleanup
514     // or suspend block of the coroutine (see @llvm.coro.suspend return code
515     // documentation).
516     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
517     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
518                                               op.cleanupDest()};
519     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
520         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
521         /*defaultDestination=*/op.suspendDest(),
522         /*defaultOperands=*/ValueRange(),
523         /*caseValues=*/caseValues,
524         /*caseDestinations=*/caseDest,
525         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
526         /*branchWeights=*/ArrayRef<int32_t>());
527 
528     return success();
529   }
530 };
531 } // namespace
532 
533 //===----------------------------------------------------------------------===//
534 // Convert async.runtime.create to the corresponding runtime API call.
535 //
536 // To allocate storage for the async values we use getelementptr trick:
537 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
538 //===----------------------------------------------------------------------===//
539 
540 namespace {
541 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
542 public:
543   using OpConversionPattern::OpConversionPattern;
544 
545   LogicalResult
546   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
547                   ConversionPatternRewriter &rewriter) const override {
548     TypeConverter *converter = getTypeConverter();
549     Type resultType = op->getResultTypes()[0];
550 
551     // Tokens creation maps to a simple function call.
552     if (resultType.isa<TokenType>()) {
553       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
554                                           converter->convertType(resultType));
555       return success();
556     }
557 
558     // To create a value we need to compute the storage requirement.
559     if (auto value = resultType.dyn_cast<ValueType>()) {
560       // Returns the size requirements for the async value storage.
561       auto sizeOf = [&](ValueType valueType) -> Value {
562         auto loc = op->getLoc();
563         auto i64 = rewriter.getI64Type();
564 
565         auto storedType = converter->convertType(valueType.getValueType());
566         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
567 
568         // %Size = getelementptr %T* null, int 1
569         // %SizeI = ptrtoint %T* %Size to i64
570         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
571         auto one = rewriter.create<LLVM::ConstantOp>(
572             loc, i64, rewriter.getI64IntegerAttr(1));
573         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
574                                                 one.getResult());
575         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
576       };
577 
578       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
579                                           sizeOf(value));
580 
581       return success();
582     }
583 
584     return rewriter.notifyMatchFailure(op, "unsupported async type");
585   }
586 };
587 } // namespace
588 
589 //===----------------------------------------------------------------------===//
590 // Convert async.runtime.create_group to the corresponding runtime API call.
591 //===----------------------------------------------------------------------===//
592 
593 namespace {
594 class RuntimeCreateGroupOpLowering
595     : public OpConversionPattern<RuntimeCreateGroupOp> {
596 public:
597   using OpConversionPattern::OpConversionPattern;
598 
599   LogicalResult
600   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
601                   ConversionPatternRewriter &rewriter) const override {
602     TypeConverter *converter = getTypeConverter();
603     Type resultType = op.getResult().getType();
604 
605     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup,
606                                         converter->convertType(resultType),
607                                         adaptor.getOperands());
608     return success();
609   }
610 };
611 } // namespace
612 
613 //===----------------------------------------------------------------------===//
614 // Convert async.runtime.set_available to the corresponding runtime API call.
615 //===----------------------------------------------------------------------===//
616 
617 namespace {
618 class RuntimeSetAvailableOpLowering
619     : public OpConversionPattern<RuntimeSetAvailableOp> {
620 public:
621   using OpConversionPattern::OpConversionPattern;
622 
623   LogicalResult
624   matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
625                   ConversionPatternRewriter &rewriter) const override {
626     StringRef apiFuncName =
627         TypeSwitch<Type, StringRef>(op.operand().getType())
628             .Case<TokenType>([](Type) { return kEmplaceToken; })
629             .Case<ValueType>([](Type) { return kEmplaceValue; });
630 
631     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
632                                         adaptor.getOperands());
633 
634     return success();
635   }
636 };
637 } // namespace
638 
639 //===----------------------------------------------------------------------===//
640 // Convert async.runtime.set_error to the corresponding runtime API call.
641 //===----------------------------------------------------------------------===//
642 
643 namespace {
644 class RuntimeSetErrorOpLowering
645     : public OpConversionPattern<RuntimeSetErrorOp> {
646 public:
647   using OpConversionPattern::OpConversionPattern;
648 
649   LogicalResult
650   matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
651                   ConversionPatternRewriter &rewriter) const override {
652     StringRef apiFuncName =
653         TypeSwitch<Type, StringRef>(op.operand().getType())
654             .Case<TokenType>([](Type) { return kSetTokenError; })
655             .Case<ValueType>([](Type) { return kSetValueError; });
656 
657     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
658                                         adaptor.getOperands());
659 
660     return success();
661   }
662 };
663 } // namespace
664 
665 //===----------------------------------------------------------------------===//
666 // Convert async.runtime.is_error to the corresponding runtime API call.
667 //===----------------------------------------------------------------------===//
668 
669 namespace {
670 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
671 public:
672   using OpConversionPattern::OpConversionPattern;
673 
674   LogicalResult
675   matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
676                   ConversionPatternRewriter &rewriter) const override {
677     StringRef apiFuncName =
678         TypeSwitch<Type, StringRef>(op.operand().getType())
679             .Case<TokenType>([](Type) { return kIsTokenError; })
680             .Case<GroupType>([](Type) { return kIsGroupError; })
681             .Case<ValueType>([](Type) { return kIsValueError; });
682 
683     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
684                                         adaptor.getOperands());
685     return success();
686   }
687 };
688 } // namespace
689 
690 //===----------------------------------------------------------------------===//
691 // Convert async.runtime.await to the corresponding runtime API call.
692 //===----------------------------------------------------------------------===//
693 
694 namespace {
695 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
696 public:
697   using OpConversionPattern::OpConversionPattern;
698 
699   LogicalResult
700   matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
701                   ConversionPatternRewriter &rewriter) const override {
702     StringRef apiFuncName =
703         TypeSwitch<Type, StringRef>(op.operand().getType())
704             .Case<TokenType>([](Type) { return kAwaitToken; })
705             .Case<ValueType>([](Type) { return kAwaitValue; })
706             .Case<GroupType>([](Type) { return kAwaitGroup; });
707 
708     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
709                             adaptor.getOperands());
710     rewriter.eraseOp(op);
711 
712     return success();
713   }
714 };
715 } // namespace
716 
717 //===----------------------------------------------------------------------===//
718 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
719 //===----------------------------------------------------------------------===//
720 
721 namespace {
722 class RuntimeAwaitAndResumeOpLowering
723     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
724 public:
725   using OpConversionPattern::OpConversionPattern;
726 
727   LogicalResult
728   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
729                   ConversionPatternRewriter &rewriter) const override {
730     StringRef apiFuncName =
731         TypeSwitch<Type, StringRef>(op.operand().getType())
732             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
733             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
734             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
735 
736     Value operand = adaptor.operand();
737     Value handle = adaptor.handle();
738 
739     // A pointer to coroutine resume intrinsic wrapper.
740     addResumeFunction(op->getParentOfType<ModuleOp>());
741     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
742     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
743         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
744 
745     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
746                             ValueRange({operand, handle, resumePtr.getRes()}));
747     rewriter.eraseOp(op);
748 
749     return success();
750   }
751 };
752 } // namespace
753 
754 //===----------------------------------------------------------------------===//
755 // Convert async.runtime.resume to the corresponding runtime API call.
756 //===----------------------------------------------------------------------===//
757 
758 namespace {
759 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
760 public:
761   using OpConversionPattern::OpConversionPattern;
762 
763   LogicalResult
764   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
765                   ConversionPatternRewriter &rewriter) const override {
766     // A pointer to coroutine resume intrinsic wrapper.
767     addResumeFunction(op->getParentOfType<ModuleOp>());
768     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
769     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
770         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
771 
772     // Call async runtime API to execute a coroutine in the managed thread.
773     auto coroHdl = adaptor.handle();
774     rewriter.replaceOpWithNewOp<CallOp>(
775         op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
776 
777     return success();
778   }
779 };
780 } // namespace
781 
782 //===----------------------------------------------------------------------===//
783 // Convert async.runtime.store to the corresponding runtime API call.
784 //===----------------------------------------------------------------------===//
785 
786 namespace {
787 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
788 public:
789   using OpConversionPattern::OpConversionPattern;
790 
791   LogicalResult
792   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
793                   ConversionPatternRewriter &rewriter) const override {
794     Location loc = op->getLoc();
795 
796     // Get a pointer to the async value storage from the runtime.
797     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
798     auto storage = adaptor.storage();
799     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
800                                               TypeRange(i8Ptr), storage);
801 
802     // Cast from i8* to the LLVM pointer type.
803     auto valueType = op.value().getType();
804     auto llvmValueType = getTypeConverter()->convertType(valueType);
805     if (!llvmValueType)
806       return rewriter.notifyMatchFailure(
807           op, "failed to convert stored value type to LLVM type");
808 
809     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
810         loc, LLVM::LLVMPointerType::get(llvmValueType),
811         storagePtr.getResult(0));
812 
813     // Store the yielded value into the async value storage.
814     auto value = adaptor.value();
815     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
816 
817     // Erase the original runtime store operation.
818     rewriter.eraseOp(op);
819 
820     return success();
821   }
822 };
823 } // namespace
824 
825 //===----------------------------------------------------------------------===//
826 // Convert async.runtime.load to the corresponding runtime API call.
827 //===----------------------------------------------------------------------===//
828 
829 namespace {
830 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
831 public:
832   using OpConversionPattern::OpConversionPattern;
833 
834   LogicalResult
835   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
836                   ConversionPatternRewriter &rewriter) const override {
837     Location loc = op->getLoc();
838 
839     // Get a pointer to the async value storage from the runtime.
840     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
841     auto storage = adaptor.storage();
842     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
843                                               TypeRange(i8Ptr), storage);
844 
845     // Cast from i8* to the LLVM pointer type.
846     auto valueType = op.result().getType();
847     auto llvmValueType = getTypeConverter()->convertType(valueType);
848     if (!llvmValueType)
849       return rewriter.notifyMatchFailure(
850           op, "failed to convert loaded value type to LLVM type");
851 
852     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
853         loc, LLVM::LLVMPointerType::get(llvmValueType),
854         storagePtr.getResult(0));
855 
856     // Load from the casted pointer.
857     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
858 
859     return success();
860   }
861 };
862 } // namespace
863 
864 //===----------------------------------------------------------------------===//
865 // Convert async.runtime.add_to_group to the corresponding runtime API call.
866 //===----------------------------------------------------------------------===//
867 
868 namespace {
869 class RuntimeAddToGroupOpLowering
870     : public OpConversionPattern<RuntimeAddToGroupOp> {
871 public:
872   using OpConversionPattern::OpConversionPattern;
873 
874   LogicalResult
875   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
876                   ConversionPatternRewriter &rewriter) const override {
877     // Currently we can only add tokens to the group.
878     if (!op.operand().getType().isa<TokenType>())
879       return rewriter.notifyMatchFailure(op, "only token type is supported");
880 
881     // Replace with a runtime API function call.
882     rewriter.replaceOpWithNewOp<CallOp>(
883         op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
884 
885     return success();
886   }
887 };
888 } // namespace
889 
890 //===----------------------------------------------------------------------===//
891 // Async reference counting ops lowering (`async.runtime.add_ref` and
892 // `async.runtime.drop_ref` to the corresponding API calls).
893 //===----------------------------------------------------------------------===//
894 
895 namespace {
896 template <typename RefCountingOp>
897 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
898 public:
899   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
900                                  StringRef apiFunctionName)
901       : OpConversionPattern<RefCountingOp>(converter, ctx),
902         apiFunctionName(apiFunctionName) {}
903 
904   LogicalResult
905   matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
906                   ConversionPatternRewriter &rewriter) const override {
907     auto count = rewriter.create<arith::ConstantOp>(
908         op->getLoc(), rewriter.getI64Type(),
909         rewriter.getI64IntegerAttr(op.count()));
910 
911     auto operand = adaptor.operand();
912     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
913                                         ValueRange({operand, count}));
914 
915     return success();
916   }
917 
918 private:
919   StringRef apiFunctionName;
920 };
921 
922 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
923 public:
924   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
925       : RefCountingOpLowering(converter, ctx, kAddRef) {}
926 };
927 
928 class RuntimeDropRefOpLowering
929     : public RefCountingOpLowering<RuntimeDropRefOp> {
930 public:
931   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
932       : RefCountingOpLowering(converter, ctx, kDropRef) {}
933 };
934 } // namespace
935 
936 //===----------------------------------------------------------------------===//
937 // Convert return operations that return async values from async regions.
938 //===----------------------------------------------------------------------===//
939 
940 namespace {
941 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
942 public:
943   using OpConversionPattern::OpConversionPattern;
944 
945   LogicalResult
946   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
947                   ConversionPatternRewriter &rewriter) const override {
948     rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
949     return success();
950   }
951 };
952 } // namespace
953 
954 //===----------------------------------------------------------------------===//
955 
956 namespace {
957 struct ConvertAsyncToLLVMPass
958     : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
959   void runOnOperation() override;
960 };
961 } // namespace
962 
/// Runs the Async -> LLVM lowering on the whole module.
///
/// The pass first declares the runtime API and C runtime functions in the
/// module. It then builds two pattern groups: one driven by the async runtime
/// type converter, and one driven by an LLVM type converter (used where async
/// value payloads must be mapped to LLVM types). Finally it applies them with
/// a partial conversion in which the Async dialect is illegal. Any failure to
/// legalize is reported via signalPassFailure().
void ConvertAsyncToLLVMPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext *ctx = module->getContext();

  // Add declarations for most functions required by the coroutines lowering.
  // We delay adding the resume function until it's needed because it currently
  // fails to compile unless '-O0' is specified.
  // (The resume/await-and-resume patterns call addResumeFunction themselves.)
  addAsyncRuntimeApiDeclarations(module);
  addCRuntimeDeclarations(module);

  // Lower async.runtime and async.coro operations to Async Runtime API and
  // LLVM coroutine intrinsics.

  // Convert async dialect types and operations to LLVM dialect.
  AsyncRuntimeTypeConverter converter;
  RewritePatternSet patterns(ctx);

  // We use conversion to LLVM type to lower async.runtime load and store
  // operations.
  // The extra conversion lets this converter also handle async types, in
  // addition to the builtin/standard types LLVMTypeConverter already covers.
  LLVMTypeConverter llvmConverter(ctx);
  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);

  // Convert async types in function signatures and function calls.
  populateFuncOpTypeConversionPattern(patterns, converter);
  populateCallOpTypeConversionPattern(patterns, converter);

  // Convert return operations inside async.execute regions.
  patterns.add<ReturnOpOpConversion>(converter, ctx);

  // Lower async.runtime operations to the async runtime API calls.
  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
               RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
               RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
               RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
               RuntimeDropRefOpLowering>(converter, ctx);

  // Lower async.runtime operations that rely on LLVM type converter to convert
  // from async value payload type to the LLVM type.
  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
               RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
                                                              ctx);

  // Lower async coroutine operations to LLVM coroutine intrinsics.
  patterns
      .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
           CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
          converter, ctx);

  // Constants and unrealized casts produced by the patterns above, plus the
  // whole LLVM dialect, are the legal "output" of this conversion.
  ConversionTarget target(*ctx);
  target
      .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>();
  target.addLegalDialect<LLVM::LLVMDialect>();

  // All operations from Async dialect must be lowered to the runtime API and
  // LLVM intrinsics calls.
  target.addIllegalDialect<AsyncDialect>();

  // Add dynamic legality constraints to apply conversions defined above.
  // Functions, returns and calls stay legal only once the async runtime type
  // converter considers their types legal (i.e. async types were rewritten).
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addDynamicallyLegalOp<ReturnOp>(
      [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
    return converter.isSignatureLegal(op.getCalleeType());
  });

  // Partial conversion: ops not marked illegal may remain unconverted, but
  // any illegal (Async) op that cannot be lowered fails the pass.
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
1032 
1033 //===----------------------------------------------------------------------===//
1034 // Patterns for structural type conversions for the Async dialect operations.
1035 //===----------------------------------------------------------------------===//
1036 
1037 namespace {
1038 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1039 public:
1040   using OpConversionPattern::OpConversionPattern;
1041   LogicalResult
1042   matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1043                   ConversionPatternRewriter &rewriter) const override {
1044     ExecuteOp newOp =
1045         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1046     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1047                                 newOp.getRegion().end());
1048 
1049     // Set operands and update block argument and result types.
1050     newOp->setOperands(adaptor.getOperands());
1051     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1052       return failure();
1053     for (auto result : newOp.getResults())
1054       result.setType(typeConverter->convertType(result.getType()));
1055 
1056     rewriter.replaceOp(op, newOp.getResults());
1057     return success();
1058   }
1059 };
1060 
1061 // Dummy pattern to trigger the appropriate type conversion / materialization.
1062 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1063 public:
1064   using OpConversionPattern::OpConversionPattern;
1065   LogicalResult
1066   matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1067                   ConversionPatternRewriter &rewriter) const override {
1068     rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1069     return success();
1070   }
1071 };
1072 
1073 // Dummy pattern to trigger the appropriate type conversion / materialization.
1074 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1075 public:
1076   using OpConversionPattern::OpConversionPattern;
1077   LogicalResult
1078   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1079                   ConversionPatternRewriter &rewriter) const override {
1080     rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1081     return success();
1082   }
1083 };
1084 } // namespace
1085 
1086 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1087   return std::make_unique<ConvertAsyncToLLVMPass>();
1088 }
1089 
1090 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1091     TypeConverter &typeConverter, RewritePatternSet &patterns,
1092     ConversionTarget &target) {
1093   typeConverter.addConversion([&](TokenType type) { return type; });
1094   typeConverter.addConversion([&](ValueType type) {
1095     Type converted = typeConverter.convertType(type.getValueType());
1096     return converted ? ValueType::get(converted) : converted;
1097   });
1098 
1099   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1100       typeConverter, patterns.getContext());
1101 
1102   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1103       [&](Operation *op) { return typeConverter.isLegal(op); });
1104 }
1105