1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 #define DEBUG_TYPE "convert-async-to-llvm"
23 
24 using namespace mlir;
25 using namespace mlir::async;
26 
27 //===----------------------------------------------------------------------===//
28 // Async Runtime C API declaration.
29 //===----------------------------------------------------------------------===//
30 
// Symbol names of the Async Runtime C API functions. Each name is declared in
// the module by `addAsyncRuntimeApiDeclarations` and referenced by name from
// the call operations emitted by the lowering patterns in this file.

// Reference counting entry points (both share one signature, see
// `AsyncAPI::addOrDropRefFunctionType`).
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
// Constructors for tokens, values and groups.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
// Targets of the `async.runtime.set_available` lowering.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
// Targets of the `async.runtime.await` lowering.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
// Target of the `async.runtime.resume` lowering.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
// Returns the storage pointer used by the `async.runtime.store` and
// `async.runtime.load` lowerings.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
// Target of the `async.runtime.add_to_group` lowering.
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// Targets of the `async.runtime.await_and_resume` lowering.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
52 
53 namespace {
54 /// Async Runtime API function types.
55 ///
56 /// Because we can't create API function signature for type parametrized
57 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
58 /// lowering all async data types become opaque pointers at runtime.
59 struct AsyncAPI {
60   // All async types are lowered to opaque i8* LLVM pointers at runtime.
61   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
62     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
63   }
64 
65   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
66     return LLVM::LLVMTokenType::get(ctx);
67   }
68 
69   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
70     auto ref = opaquePointerType(ctx);
71     auto count = IntegerType::get(ctx, 32);
72     return FunctionType::get(ctx, {ref, count}, {});
73   }
74 
75   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
76     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
77   }
78 
79   static FunctionType createValueFunctionType(MLIRContext *ctx) {
80     auto i32 = IntegerType::get(ctx, 32);
81     auto value = opaquePointerType(ctx);
82     return FunctionType::get(ctx, {i32}, {value});
83   }
84 
85   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
86     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
87   }
88 
89   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
90     auto value = opaquePointerType(ctx);
91     auto storage = opaquePointerType(ctx);
92     return FunctionType::get(ctx, {value}, {storage});
93   }
94 
95   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
96     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
97   }
98 
99   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
100     auto value = opaquePointerType(ctx);
101     return FunctionType::get(ctx, {value}, {});
102   }
103 
104   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
105     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
106   }
107 
108   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
109     auto value = opaquePointerType(ctx);
110     return FunctionType::get(ctx, {value}, {});
111   }
112 
113   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
114     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
115   }
116 
117   static FunctionType executeFunctionType(MLIRContext *ctx) {
118     auto hdl = opaquePointerType(ctx);
119     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
120     return FunctionType::get(ctx, {hdl, resume}, {});
121   }
122 
123   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
124     auto i64 = IntegerType::get(ctx, 64);
125     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
126                              {i64});
127   }
128 
129   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
130     auto hdl = opaquePointerType(ctx);
131     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
132     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
133   }
134 
135   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
136     auto value = opaquePointerType(ctx);
137     auto hdl = opaquePointerType(ctx);
138     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
139     return FunctionType::get(ctx, {value, hdl, resume}, {});
140   }
141 
142   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
143     auto hdl = opaquePointerType(ctx);
144     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
145     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
146   }
147 
148   // Auxiliary coroutine resume intrinsic wrapper.
149   static Type resumeFunctionType(MLIRContext *ctx) {
150     auto voidTy = LLVM::LLVMVoidType::get(ctx);
151     auto i8Ptr = opaquePointerType(ctx);
152     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
153   }
154 };
155 } // namespace
156 
157 /// Adds Async Runtime C API declarations to the module.
158 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
159   auto builder =
160       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
161 
162   auto addFuncDecl = [&](StringRef name, FunctionType type) {
163     if (module.lookupSymbol(name))
164       return;
165     builder.create<FuncOp>(name, type).setPrivate();
166   };
167 
168   MLIRContext *ctx = module.getContext();
169   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
170   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
171   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
172   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
173   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
174   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
175   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
176   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
177   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
178   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
179   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
180   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
181   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
182   addFuncDecl(kAwaitTokenAndExecute,
183               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
184   addFuncDecl(kAwaitValueAndExecute,
185               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
186   addFuncDecl(kAwaitAllAndExecute,
187               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // Add malloc/free declarations to the module.
192 //===----------------------------------------------------------------------===//
193 
// Symbol names of the C runtime functions declared by
// `addCRuntimeDeclarations`. The coroutine lowerings call them by name to
// allocate (`async.coro.begin`) and release (`async.coro.free`) the coroutine
// frame.
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
196 
197 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
198                             StringRef name, Type ret, ArrayRef<Type> params) {
199   if (module.lookupSymbol(name))
200     return;
201   Type type = LLVM::LLVMFunctionType::get(ret, params);
202   builder.create<LLVM::LLVMFuncOp>(name, type);
203 }
204 
205 /// Adds malloc/free declarations to the module.
206 static void addCRuntimeDeclarations(ModuleOp module) {
207   using namespace mlir::LLVM;
208 
209   MLIRContext *ctx = module.getContext();
210   auto builder =
211       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
212 
213   auto voidTy = LLVMVoidType::get(ctx);
214   auto i64 = IntegerType::get(ctx, 64);
215   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
216 
217   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
218   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
219 }
220 
221 //===----------------------------------------------------------------------===//
222 // Coroutine resume function wrapper.
223 //===----------------------------------------------------------------------===//
224 
// Symbol name of the wrapper function created by `addResumeFunction`; the
// runtime lowerings take its address to pass it to the async runtime.
static constexpr const char *kResume = "__resume";
226 
227 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
228 /// intrinsics. We need this function to be able to pass it to the async
229 /// runtime execute API.
230 static void addResumeFunction(ModuleOp module) {
231   if (module.lookupSymbol(kResume))
232     return;
233 
234   MLIRContext *ctx = module.getContext();
235   auto loc = module.getLoc();
236   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
237 
238   auto voidTy = LLVM::LLVMVoidType::get(ctx);
239   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
240 
241   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
242       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
243   resumeOp.setPrivate();
244 
245   auto *block = resumeOp.addEntryBlock();
246   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
247 
248   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
249   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // Convert Async dialect types to LLVM types.
254 //===----------------------------------------------------------------------===//
255 
256 namespace {
257 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
258 /// their runtime type (opaque pointers) and does not convert any other types.
259 class AsyncRuntimeTypeConverter : public TypeConverter {
260 public:
261   AsyncRuntimeTypeConverter() {
262     addConversion([](Type type) { return type; });
263     addConversion(convertAsyncTypes);
264   }
265 
266   static Optional<Type> convertAsyncTypes(Type type) {
267     if (type.isa<TokenType, GroupType, ValueType>())
268       return AsyncAPI::opaquePointerType(type.getContext());
269 
270     if (type.isa<CoroIdType, CoroStateType>())
271       return AsyncAPI::tokenType(type.getContext());
272     if (type.isa<CoroHandleType>())
273       return AsyncAPI::opaquePointerType(type.getContext());
274 
275     return llvm::None;
276   }
277 };
278 } // namespace
279 
280 //===----------------------------------------------------------------------===//
281 // Convert async.coro.id to @llvm.coro.id intrinsic.
282 //===----------------------------------------------------------------------===//
283 
284 namespace {
285 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
286 public:
287   using OpConversionPattern::OpConversionPattern;
288 
289   LogicalResult
290   matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
291                   ConversionPatternRewriter &rewriter) const override {
292     auto token = AsyncAPI::tokenType(op->getContext());
293     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
294     auto loc = op->getLoc();
295 
296     // Constants for initializing coroutine frame.
297     auto constZero = rewriter.create<LLVM::ConstantOp>(
298         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
299     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
300 
301     // Get coroutine id: @llvm.coro.id.
302     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
303         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
304 
305     return success();
306   }
307 };
308 } // namespace
309 
310 //===----------------------------------------------------------------------===//
311 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
312 //===----------------------------------------------------------------------===//
313 
314 namespace {
315 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
316 public:
317   using OpConversionPattern::OpConversionPattern;
318 
319   LogicalResult
320   matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
321                   ConversionPatternRewriter &rewriter) const override {
322     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
323     auto loc = op->getLoc();
324 
325     // Get coroutine frame size: @llvm.coro.size.i64.
326     auto coroSize =
327         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
328 
329     // Allocate memory for the coroutine frame.
330     auto coroAlloc = rewriter.create<LLVM::CallOp>(
331         loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
332         ValueRange(coroSize.getResult()));
333 
334     // Begin a coroutine: @llvm.coro.begin.
335     auto coroId = CoroBeginOpAdaptor(operands).id();
336     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
337         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
338 
339     return success();
340   }
341 };
342 } // namespace
343 
344 //===----------------------------------------------------------------------===//
345 // Convert async.coro.free to @llvm.coro.free intrinsic.
346 //===----------------------------------------------------------------------===//
347 
348 namespace {
349 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
350 public:
351   using OpConversionPattern::OpConversionPattern;
352 
353   LogicalResult
354   matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
355                   ConversionPatternRewriter &rewriter) const override {
356     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
357     auto loc = op->getLoc();
358 
359     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
360     auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
361 
362     // Free the memory.
363     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
364                                               rewriter.getSymbolRefAttr(kFree),
365                                               ValueRange(coroMem.getResult()));
366 
367     return success();
368   }
369 };
370 } // namespace
371 
372 //===----------------------------------------------------------------------===//
373 // Convert async.coro.end to @llvm.coro.end intrinsic.
374 //===----------------------------------------------------------------------===//
375 
376 namespace {
377 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
378 public:
379   using OpConversionPattern::OpConversionPattern;
380 
381   LogicalResult
382   matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
383                   ConversionPatternRewriter &rewriter) const override {
384     // We are not in the block that is part of the unwind sequence.
385     auto constFalse = rewriter.create<LLVM::ConstantOp>(
386         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
387 
388     // Mark the end of a coroutine: @llvm.coro.end.
389     auto coroHdl = CoroEndOpAdaptor(operands).handle();
390     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
391                                      ValueRange({coroHdl, constFalse}));
392     rewriter.eraseOp(op);
393 
394     return success();
395   }
396 };
397 } // namespace
398 
399 //===----------------------------------------------------------------------===//
400 // Convert async.coro.save to @llvm.coro.save intrinsic.
401 //===----------------------------------------------------------------------===//
402 
403 namespace {
404 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
405 public:
406   using OpConversionPattern::OpConversionPattern;
407 
408   LogicalResult
409   matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
410                   ConversionPatternRewriter &rewriter) const override {
411     // Save the coroutine state: @llvm.coro.save
412     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
413         op, AsyncAPI::tokenType(op->getContext()), operands);
414 
415     return success();
416   }
417 };
418 } // namespace
419 
420 //===----------------------------------------------------------------------===//
421 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
422 //===----------------------------------------------------------------------===//
423 
424 namespace {
425 
426 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
427 /// branch to the appropriate block based on the return code.
428 ///
429 /// Before:
430 ///
431 ///   ^suspended:
432 ///     "opBefore"(...)
433 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
434 ///   ^resume:
435 ///     "op"(...)
436 ///   ^cleanup: ...
437 ///   ^suspend: ...
438 ///
439 /// After:
440 ///
441 ///   ^suspended:
442 ///     "opBefore"(...)
443 ///     %suspend = llmv.intr.coro.suspend ...
444 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
445 ///   ^resume:
446 ///     "op"(...)
447 ///   ^cleanup: ...
448 ///   ^suspend: ...
449 ///
450 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
451 public:
452   using OpConversionPattern::OpConversionPattern;
453 
454   LogicalResult
455   matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
456                   ConversionPatternRewriter &rewriter) const override {
457     auto i8 = rewriter.getIntegerType(8);
458     auto i32 = rewriter.getI32Type();
459     auto loc = op->getLoc();
460 
461     // This is not a final suspension point.
462     auto constFalse = rewriter.create<LLVM::ConstantOp>(
463         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
464 
465     // Suspend a coroutine: @llvm.coro.suspend
466     auto coroState = CoroSuspendOpAdaptor(operands).state();
467     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
468         loc, i8, ValueRange({coroState, constFalse}));
469 
470     // Cast return code to i32.
471 
472     // After a suspension point decide if we should branch into resume, cleanup
473     // or suspend block of the coroutine (see @llvm.coro.suspend return code
474     // documentation).
475     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
476     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
477                                               op.cleanupDest()};
478     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
479         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
480         /*defaultDestination=*/op.suspendDest(),
481         /*defaultOperands=*/ValueRange(),
482         /*caseValues=*/caseValues,
483         /*caseDestinations=*/caseDest,
484         /*caseOperands=*/ArrayRef<ValueRange>(),
485         /*branchWeights=*/ArrayRef<int32_t>());
486 
487     return success();
488   }
489 };
490 } // namespace
491 
492 //===----------------------------------------------------------------------===//
493 // Convert async.runtime.create to the corresponding runtime API call.
494 //
495 // To allocate storage for the async values we use getelementptr trick:
496 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
497 //===----------------------------------------------------------------------===//
498 
499 namespace {
500 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
501 public:
502   using OpConversionPattern::OpConversionPattern;
503 
504   LogicalResult
505   matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
506                   ConversionPatternRewriter &rewriter) const override {
507     TypeConverter *converter = getTypeConverter();
508     Type resultType = op->getResultTypes()[0];
509 
510     // Tokens and Groups lowered to function calls without arguments.
511     if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
512       rewriter.replaceOpWithNewOp<CallOp>(
513           op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
514           converter->convertType(resultType));
515       return success();
516     }
517 
518     // To create a value we need to compute the storage requirement.
519     if (auto value = resultType.dyn_cast<ValueType>()) {
520       // Returns the size requirements for the async value storage.
521       auto sizeOf = [&](ValueType valueType) -> Value {
522         auto loc = op->getLoc();
523         auto i32 = rewriter.getI32Type();
524 
525         auto storedType = converter->convertType(valueType.getValueType());
526         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
527 
528         // %Size = getelementptr %T* null, int 1
529         // %SizeI = ptrtoint %T* %Size to i32
530         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
531         auto one = rewriter.create<LLVM::ConstantOp>(
532             loc, i32, rewriter.getI32IntegerAttr(1));
533         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
534                                                 one.getResult());
535         return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep);
536       };
537 
538       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
539                                           sizeOf(value));
540 
541       return success();
542     }
543 
544     return rewriter.notifyMatchFailure(op, "unsupported async type");
545   }
546 };
547 } // namespace
548 
549 //===----------------------------------------------------------------------===//
550 // Convert async.runtime.set_available to the corresponding runtime API call.
551 //===----------------------------------------------------------------------===//
552 
553 namespace {
554 class RuntimeSetAvailableOpLowering
555     : public OpConversionPattern<RuntimeSetAvailableOp> {
556 public:
557   using OpConversionPattern::OpConversionPattern;
558 
559   LogicalResult
560   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
561                   ConversionPatternRewriter &rewriter) const override {
562     Type operandType = op.operand().getType();
563 
564     if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) {
565       rewriter.create<CallOp>(op->getLoc(),
566                               operandType.isa<TokenType>() ? kEmplaceToken
567                                                            : kEmplaceValue,
568                               TypeRange(), operands);
569       rewriter.eraseOp(op);
570       return success();
571     }
572 
573     return rewriter.notifyMatchFailure(op, "unsupported async type");
574   }
575 };
576 } // namespace
577 
578 //===----------------------------------------------------------------------===//
579 // Convert async.runtime.await to the corresponding runtime API call.
580 //===----------------------------------------------------------------------===//
581 
582 namespace {
583 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
584 public:
585   using OpConversionPattern::OpConversionPattern;
586 
587   LogicalResult
588   matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
589                   ConversionPatternRewriter &rewriter) const override {
590     Type operandType = op.operand().getType();
591 
592     StringRef apiFuncName;
593     if (operandType.isa<TokenType>())
594       apiFuncName = kAwaitToken;
595     else if (operandType.isa<ValueType>())
596       apiFuncName = kAwaitValue;
597     else if (operandType.isa<GroupType>())
598       apiFuncName = kAwaitGroup;
599     else
600       return rewriter.notifyMatchFailure(op, "unsupported async type");
601 
602     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
603     rewriter.eraseOp(op);
604 
605     return success();
606   }
607 };
608 } // namespace
609 
610 //===----------------------------------------------------------------------===//
611 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
612 //===----------------------------------------------------------------------===//
613 
614 namespace {
615 class RuntimeAwaitAndResumeOpLowering
616     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
617 public:
618   using OpConversionPattern::OpConversionPattern;
619 
620   LogicalResult
621   matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
622                   ConversionPatternRewriter &rewriter) const override {
623     Type operandType = op.operand().getType();
624 
625     StringRef apiFuncName;
626     if (operandType.isa<TokenType>())
627       apiFuncName = kAwaitTokenAndExecute;
628     else if (operandType.isa<ValueType>())
629       apiFuncName = kAwaitValueAndExecute;
630     else if (operandType.isa<GroupType>())
631       apiFuncName = kAwaitAllAndExecute;
632     else
633       return rewriter.notifyMatchFailure(op, "unsupported async type");
634 
635     Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
636     Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
637 
638     // A pointer to coroutine resume intrinsic wrapper.
639     addResumeFunction(op->getParentOfType<ModuleOp>());
640     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
641     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
642         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
643 
644     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
645                             ValueRange({operand, handle, resumePtr.res()}));
646     rewriter.eraseOp(op);
647 
648     return success();
649   }
650 };
651 } // namespace
652 
653 //===----------------------------------------------------------------------===//
654 // Convert async.runtime.resume to the corresponding runtime API call.
655 //===----------------------------------------------------------------------===//
656 
657 namespace {
658 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
659 public:
660   using OpConversionPattern::OpConversionPattern;
661 
662   LogicalResult
663   matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
664                   ConversionPatternRewriter &rewriter) const override {
665     // A pointer to coroutine resume intrinsic wrapper.
666     addResumeFunction(op->getParentOfType<ModuleOp>());
667     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
668     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
669         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
670 
671     // Call async runtime API to execute a coroutine in the managed thread.
672     auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
673     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
674                                         ValueRange({coroHdl, resumePtr.res()}));
675 
676     return success();
677   }
678 };
679 } // namespace
680 
681 //===----------------------------------------------------------------------===//
682 // Convert async.runtime.store to the corresponding runtime API call.
683 //===----------------------------------------------------------------------===//
684 
685 namespace {
686 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
687 public:
688   using OpConversionPattern::OpConversionPattern;
689 
690   LogicalResult
691   matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
692                   ConversionPatternRewriter &rewriter) const override {
693     Location loc = op->getLoc();
694 
695     // Get a pointer to the async value storage from the runtime.
696     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
697     auto storage = RuntimeStoreOpAdaptor(operands).storage();
698     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
699                                               TypeRange(i8Ptr), storage);
700 
701     // Cast from i8* to the LLVM pointer type.
702     auto valueType = op.value().getType();
703     auto llvmValueType = getTypeConverter()->convertType(valueType);
704     if (!llvmValueType)
705       return rewriter.notifyMatchFailure(
706           op, "failed to convert stored value type to LLVM type");
707 
708     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
709         loc, LLVM::LLVMPointerType::get(llvmValueType),
710         storagePtr.getResult(0));
711 
712     // Store the yielded value into the async value storage.
713     auto value = RuntimeStoreOpAdaptor(operands).value();
714     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
715 
716     // Erase the original runtime store operation.
717     rewriter.eraseOp(op);
718 
719     return success();
720   }
721 };
722 } // namespace
723 
724 //===----------------------------------------------------------------------===//
725 // Convert async.runtime.load to the corresponding runtime API call.
726 //===----------------------------------------------------------------------===//
727 
728 namespace {
729 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
730 public:
731   using OpConversionPattern::OpConversionPattern;
732 
733   LogicalResult
734   matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
735                   ConversionPatternRewriter &rewriter) const override {
736     Location loc = op->getLoc();
737 
738     // Get a pointer to the async value storage from the runtime.
739     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
740     auto storage = RuntimeLoadOpAdaptor(operands).storage();
741     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
742                                               TypeRange(i8Ptr), storage);
743 
744     // Cast from i8* to the LLVM pointer type.
745     auto valueType = op.result().getType();
746     auto llvmValueType = getTypeConverter()->convertType(valueType);
747     if (!llvmValueType)
748       return rewriter.notifyMatchFailure(
749           op, "failed to convert loaded value type to LLVM type");
750 
751     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
752         loc, LLVM::LLVMPointerType::get(llvmValueType),
753         storagePtr.getResult(0));
754 
755     // Load from the casted pointer.
756     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
757 
758     return success();
759   }
760 };
761 } // namespace
762 
763 //===----------------------------------------------------------------------===//
764 // Convert async.runtime.add_to_group to the corresponding runtime API call.
765 //===----------------------------------------------------------------------===//
766 
767 namespace {
768 class RuntimeAddToGroupOpLowering
769     : public OpConversionPattern<RuntimeAddToGroupOp> {
770 public:
771   using OpConversionPattern::OpConversionPattern;
772 
773   LogicalResult
774   matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
775                   ConversionPatternRewriter &rewriter) const override {
776     // Currently we can only add tokens to the group.
777     if (!op.operand().getType().isa<TokenType>())
778       return rewriter.notifyMatchFailure(op, "only token type is supported");
779 
780     // Replace with a runtime API function call.
781     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
782                                         rewriter.getI64Type(), operands);
783 
784     return success();
785   }
786 };
787 } // namespace
788 
789 //===----------------------------------------------------------------------===//
790 // Async reference counting ops lowering (`async.runtime.add_ref` and
791 // `async.runtime.drop_ref` to the corresponding API calls).
792 //===----------------------------------------------------------------------===//
793 
794 namespace {
795 template <typename RefCountingOp>
796 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
797 public:
798   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
799                                  StringRef apiFunctionName)
800       : OpConversionPattern<RefCountingOp>(converter, ctx),
801         apiFunctionName(apiFunctionName) {}
802 
803   LogicalResult
804   matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
805                   ConversionPatternRewriter &rewriter) const override {
806     auto count =
807         rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
808                                     rewriter.getI32IntegerAttr(op.count()));
809 
810     auto operand = typename RefCountingOp::Adaptor(operands).operand();
811     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
812                                         ValueRange({operand, count}));
813 
814     return success();
815   }
816 
817 private:
818   StringRef apiFunctionName;
819 };
820 
821 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
822 public:
823   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
824       : RefCountingOpLowering(converter, ctx, kAddRef) {}
825 };
826 
827 class RuntimeDropRefOpLowering
828     : public RefCountingOpLowering<RuntimeDropRefOp> {
829 public:
830   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
831       : RefCountingOpLowering(converter, ctx, kDropRef) {}
832 };
833 } // namespace
834 
835 //===----------------------------------------------------------------------===//
836 // Convert return operations that return async values from async regions.
837 //===----------------------------------------------------------------------===//
838 
839 namespace {
840 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
841 public:
842   using OpConversionPattern::OpConversionPattern;
843 
844   LogicalResult
845   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
846                   ConversionPatternRewriter &rewriter) const override {
847     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
848     return success();
849   }
850 };
851 } // namespace
852 
853 //===----------------------------------------------------------------------===//
854 
855 namespace {
/// Module pass that lowers Async dialect runtime and coroutine operations to
/// Async Runtime API calls and LLVM dialect operations. All the work is done
/// in `runOnOperation`, which is defined out of line.
struct ConvertAsyncToLLVMPass
    : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
  /// Runs the conversion on the module and signals pass failure if the
  /// partial conversion does not succeed.
  void runOnOperation() override;
};
860 } // namespace
861 
862 void ConvertAsyncToLLVMPass::runOnOperation() {
863   ModuleOp module = getOperation();
864   MLIRContext *ctx = module->getContext();
865 
866   // Add declarations for most functions required by the coroutines lowering.
867   // We delay adding the resume function until it's needed because it currently
868   // fails to compile unless '-O0' is specified.
869   addAsyncRuntimeApiDeclarations(module);
870   addCRuntimeDeclarations(module);
871 
872   // Lower async.runtime and async.coro operations to Async Runtime API and
873   // LLVM coroutine intrinsics.
874 
875   // Convert async dialect types and operations to LLVM dialect.
876   AsyncRuntimeTypeConverter converter;
877   RewritePatternSet patterns(ctx);
878 
879   // We use conversion to LLVM type to lower async.runtime load and store
880   // operations.
881   LLVMTypeConverter llvmConverter(ctx);
882   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
883 
884   // Convert async types in function signatures and function calls.
885   populateFuncOpTypeConversionPattern(patterns, converter);
886   populateCallOpTypeConversionPattern(patterns, converter);
887 
888   // Convert return operations inside async.execute regions.
889   patterns.add<ReturnOpOpConversion>(converter, ctx);
890 
891   // Lower async.runtime operations to the async runtime API calls.
892   patterns.add<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
893                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
894                RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
895                RuntimeDropRefOpLowering>(converter, ctx);
896 
897   // Lower async.runtime operations that rely on LLVM type converter to convert
898   // from async value payload type to the LLVM type.
899   patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
900                RuntimeLoadOpLowering>(llvmConverter, ctx);
901 
902   // Lower async coroutine operations to LLVM coroutine intrinsics.
903   patterns
904       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
905            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
906           converter, ctx);
907 
908   ConversionTarget target(*ctx);
909   target.addLegalOp<ConstantOp>();
910   target.addLegalDialect<LLVM::LLVMDialect>();
911 
912   // All operations from Async dialect must be lowered to the runtime API and
913   // LLVM intrinsics calls.
914   target.addIllegalDialect<AsyncDialect>();
915 
916   // Add dynamic legality constraints to apply conversions defined above.
917   target.addDynamicallyLegalOp<FuncOp>(
918       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
919   target.addDynamicallyLegalOp<ReturnOp>(
920       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
921   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
922     return converter.isSignatureLegal(op.getCalleeType());
923   });
924 
925   if (failed(applyPartialConversion(module, target, std::move(patterns))))
926     signalPassFailure();
927 }
928 
929 //===----------------------------------------------------------------------===//
930 // Patterns for structural type conversions for the Async dialect operations.
931 //===----------------------------------------------------------------------===//
932 
933 namespace {
934 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
935 public:
936   using OpConversionPattern::OpConversionPattern;
937   LogicalResult
938   matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
939                   ConversionPatternRewriter &rewriter) const override {
940     ExecuteOp newOp =
941         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
942     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
943                                 newOp.getRegion().end());
944 
945     // Set operands and update block argument and result types.
946     newOp->setOperands(operands);
947     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
948       return failure();
949     for (auto result : newOp.getResults())
950       result.setType(typeConverter->convertType(result.getType()));
951 
952     rewriter.replaceOp(op, newOp.getResults());
953     return success();
954   }
955 };
956 
957 // Dummy pattern to trigger the appropriate type conversion / materialization.
958 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
959 public:
960   using OpConversionPattern::OpConversionPattern;
961   LogicalResult
962   matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
963                   ConversionPatternRewriter &rewriter) const override {
964     rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
965     return success();
966   }
967 };
968 
969 // Dummy pattern to trigger the appropriate type conversion / materialization.
970 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
971 public:
972   using OpConversionPattern::OpConversionPattern;
973   LogicalResult
974   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
975                   ConversionPatternRewriter &rewriter) const override {
976     rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
977     return success();
978   }
979 };
980 } // namespace
981 
/// Creates a new instance of the module pass that converts Async dialect
/// operations to the LLVM dialect. Ownership of the pass is transferred to the
/// caller.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
  return std::make_unique<ConvertAsyncToLLVMPass>();
}
985 
986 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
987     TypeConverter &typeConverter, RewritePatternSet &patterns,
988     ConversionTarget &target) {
989   typeConverter.addConversion([&](TokenType type) { return type; });
990   typeConverter.addConversion([&](ValueType type) {
991     return ValueType::get(typeConverter.convertType(type.getValueType()));
992   });
993 
994   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
995       typeConverter, patterns.getContext());
996 
997   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
998       [&](Operation *op) { return typeConverter.isLegal(op); });
999 }
1000