1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/Async/IR/Async.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 #define DEBUG_TYPE "convert-async-to-llvm"
24 
25 using namespace mlir;
26 using namespace mlir::async;
27 
28 //===----------------------------------------------------------------------===//
29 // Async Runtime C API declaration.
30 //===----------------------------------------------------------------------===//
31 
// Symbol names of the Async Runtime C API functions. Every name listed here is
// declared in the module by `addAsyncRuntimeApiDeclarations`, with the
// signature returned by the matching `AsyncAPI::*FunctionType` helper, and is
// used as the callee of the calls emitted by the lowering patterns.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
// The three names below are the callees chosen by the
// `async.runtime.await_and_resume` lowering for token, value and group
// operands respectively.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
58 
59 namespace {
60 /// Async Runtime API function types.
61 ///
62 /// Because we can't create API function signature for type parametrized
63 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
64 /// lowering all async data types become opaque pointers at runtime.
65 struct AsyncAPI {
66   // All async types are lowered to opaque i8* LLVM pointers at runtime.
67   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
68     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
69   }
70 
71   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
72     return LLVM::LLVMTokenType::get(ctx);
73   }
74 
75   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
76     auto ref = opaquePointerType(ctx);
77     auto count = IntegerType::get(ctx, 32);
78     return FunctionType::get(ctx, {ref, count}, {});
79   }
80 
81   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
82     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
83   }
84 
85   static FunctionType createValueFunctionType(MLIRContext *ctx) {
86     auto i32 = IntegerType::get(ctx, 32);
87     auto value = opaquePointerType(ctx);
88     return FunctionType::get(ctx, {i32}, {value});
89   }
90 
91   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
92     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
93   }
94 
95   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
96     auto value = opaquePointerType(ctx);
97     auto storage = opaquePointerType(ctx);
98     return FunctionType::get(ctx, {value}, {storage});
99   }
100 
101   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
102     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
103   }
104 
105   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
106     auto value = opaquePointerType(ctx);
107     return FunctionType::get(ctx, {value}, {});
108   }
109 
110   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
111     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
112   }
113 
114   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
115     auto value = opaquePointerType(ctx);
116     return FunctionType::get(ctx, {value}, {});
117   }
118 
119   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
120     auto i1 = IntegerType::get(ctx, 1);
121     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
122   }
123 
124   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
125     auto value = opaquePointerType(ctx);
126     auto i1 = IntegerType::get(ctx, 1);
127     return FunctionType::get(ctx, {value}, {i1});
128   }
129 
130   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
131     auto i1 = IntegerType::get(ctx, 1);
132     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
133   }
134 
135   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
136     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
137   }
138 
139   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
140     auto value = opaquePointerType(ctx);
141     return FunctionType::get(ctx, {value}, {});
142   }
143 
144   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
145     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
146   }
147 
148   static FunctionType executeFunctionType(MLIRContext *ctx) {
149     auto hdl = opaquePointerType(ctx);
150     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
151     return FunctionType::get(ctx, {hdl, resume}, {});
152   }
153 
154   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
155     auto i64 = IntegerType::get(ctx, 64);
156     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
157                              {i64});
158   }
159 
160   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
161     auto hdl = opaquePointerType(ctx);
162     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
163     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
164   }
165 
166   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
167     auto value = opaquePointerType(ctx);
168     auto hdl = opaquePointerType(ctx);
169     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
170     return FunctionType::get(ctx, {value, hdl, resume}, {});
171   }
172 
173   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
174     auto hdl = opaquePointerType(ctx);
175     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
176     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
177   }
178 
179   // Auxiliary coroutine resume intrinsic wrapper.
180   static Type resumeFunctionType(MLIRContext *ctx) {
181     auto voidTy = LLVM::LLVMVoidType::get(ctx);
182     auto i8Ptr = opaquePointerType(ctx);
183     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
184   }
185 };
186 } // namespace
187 
188 /// Adds Async Runtime C API declarations to the module.
189 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
190   auto builder =
191       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
192 
193   auto addFuncDecl = [&](StringRef name, FunctionType type) {
194     if (module.lookupSymbol(name))
195       return;
196     builder.create<FuncOp>(name, type).setPrivate();
197   };
198 
199   MLIRContext *ctx = module.getContext();
200   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
201   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
202   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
203   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
204   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
205   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
206   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
207   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
208   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
209   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
210   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
211   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
212   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
213   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
214   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
215   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
216   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
217   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
218   addFuncDecl(kAwaitTokenAndExecute,
219               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
220   addFuncDecl(kAwaitValueAndExecute,
221               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
222   addFuncDecl(kAwaitAllAndExecute,
223               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // Add malloc/free declarations to the module.
228 //===----------------------------------------------------------------------===//
229 
// Symbol names of the C allocation functions declared by
// `addCRuntimeDeclarations`. `kMalloc` is called by the `async.coro.begin`
// lowering and `kFree` by the `async.coro.free` lowering.
static constexpr const char *kMalloc = "malloc";
static constexpr const char *kFree = "free";
232 
233 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
234                             StringRef name, Type ret, ArrayRef<Type> params) {
235   if (module.lookupSymbol(name))
236     return;
237   Type type = LLVM::LLVMFunctionType::get(ret, params);
238   builder.create<LLVM::LLVMFuncOp>(name, type);
239 }
240 
241 /// Adds malloc/free declarations to the module.
242 static void addCRuntimeDeclarations(ModuleOp module) {
243   using namespace mlir::LLVM;
244 
245   MLIRContext *ctx = module.getContext();
246   auto builder =
247       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
248 
249   auto voidTy = LLVMVoidType::get(ctx);
250   auto i64 = IntegerType::get(ctx, 64);
251   auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8));
252 
253   addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
254   addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // Coroutine resume function wrapper.
259 //===----------------------------------------------------------------------===//
260 
// Symbol name of the coroutine resume wrapper defined by `addResumeFunction`;
// the runtime lowerings take its address and pass it to the runtime API.
static constexpr const char *kResume = "__resume";
262 
263 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
264 /// intrinsics. We need this function to be able to pass it to the async
265 /// runtime execute API.
266 static void addResumeFunction(ModuleOp module) {
267   if (module.lookupSymbol(kResume))
268     return;
269 
270   MLIRContext *ctx = module.getContext();
271   auto loc = module.getLoc();
272   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
273 
274   auto voidTy = LLVM::LLVMVoidType::get(ctx);
275   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
276 
277   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
278       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
279   resumeOp.setPrivate();
280 
281   auto *block = resumeOp.addEntryBlock();
282   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
283 
284   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
285   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // Convert Async dialect types to LLVM types.
290 //===----------------------------------------------------------------------===//
291 
292 namespace {
293 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
294 /// their runtime type (opaque pointers) and does not convert any other types.
295 class AsyncRuntimeTypeConverter : public TypeConverter {
296 public:
297   AsyncRuntimeTypeConverter() {
298     addConversion([](Type type) { return type; });
299     addConversion(convertAsyncTypes);
300   }
301 
302   static Optional<Type> convertAsyncTypes(Type type) {
303     if (type.isa<TokenType, GroupType, ValueType>())
304       return AsyncAPI::opaquePointerType(type.getContext());
305 
306     if (type.isa<CoroIdType, CoroStateType>())
307       return AsyncAPI::tokenType(type.getContext());
308     if (type.isa<CoroHandleType>())
309       return AsyncAPI::opaquePointerType(type.getContext());
310 
311     return llvm::None;
312   }
313 };
314 } // namespace
315 
316 //===----------------------------------------------------------------------===//
317 // Convert async.coro.id to @llvm.coro.id intrinsic.
318 //===----------------------------------------------------------------------===//
319 
320 namespace {
321 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
322 public:
323   using OpConversionPattern::OpConversionPattern;
324 
325   LogicalResult
326   matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
327                   ConversionPatternRewriter &rewriter) const override {
328     auto token = AsyncAPI::tokenType(op->getContext());
329     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
330     auto loc = op->getLoc();
331 
332     // Constants for initializing coroutine frame.
333     auto constZero = rewriter.create<LLVM::ConstantOp>(
334         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
335     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
336 
337     // Get coroutine id: @llvm.coro.id.
338     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
339         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
340 
341     return success();
342   }
343 };
344 } // namespace
345 
346 //===----------------------------------------------------------------------===//
347 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
348 //===----------------------------------------------------------------------===//
349 
350 namespace {
351 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
352 public:
353   using OpConversionPattern::OpConversionPattern;
354 
355   LogicalResult
356   matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
357                   ConversionPatternRewriter &rewriter) const override {
358     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
359     auto loc = op->getLoc();
360 
361     // Get coroutine frame size: @llvm.coro.size.i64.
362     auto coroSize =
363         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
364 
365     // Allocate memory for the coroutine frame.
366     auto coroAlloc = rewriter.create<LLVM::CallOp>(
367         loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
368         ValueRange(coroSize.getResult()));
369 
370     // Begin a coroutine: @llvm.coro.begin.
371     auto coroId = CoroBeginOpAdaptor(operands).id();
372     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
373         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
374 
375     return success();
376   }
377 };
378 } // namespace
379 
380 //===----------------------------------------------------------------------===//
381 // Convert async.coro.free to @llvm.coro.free intrinsic.
382 //===----------------------------------------------------------------------===//
383 
384 namespace {
385 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
386 public:
387   using OpConversionPattern::OpConversionPattern;
388 
389   LogicalResult
390   matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
391                   ConversionPatternRewriter &rewriter) const override {
392     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
393     auto loc = op->getLoc();
394 
395     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
396     auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
397 
398     // Free the memory.
399     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
400                                               rewriter.getSymbolRefAttr(kFree),
401                                               ValueRange(coroMem.getResult()));
402 
403     return success();
404   }
405 };
406 } // namespace
407 
408 //===----------------------------------------------------------------------===//
409 // Convert async.coro.end to @llvm.coro.end intrinsic.
410 //===----------------------------------------------------------------------===//
411 
412 namespace {
413 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
414 public:
415   using OpConversionPattern::OpConversionPattern;
416 
417   LogicalResult
418   matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
419                   ConversionPatternRewriter &rewriter) const override {
420     // We are not in the block that is part of the unwind sequence.
421     auto constFalse = rewriter.create<LLVM::ConstantOp>(
422         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
423 
424     // Mark the end of a coroutine: @llvm.coro.end.
425     auto coroHdl = CoroEndOpAdaptor(operands).handle();
426     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
427                                      ValueRange({coroHdl, constFalse}));
428     rewriter.eraseOp(op);
429 
430     return success();
431   }
432 };
433 } // namespace
434 
435 //===----------------------------------------------------------------------===//
436 // Convert async.coro.save to @llvm.coro.save intrinsic.
437 //===----------------------------------------------------------------------===//
438 
439 namespace {
440 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
441 public:
442   using OpConversionPattern::OpConversionPattern;
443 
444   LogicalResult
445   matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
446                   ConversionPatternRewriter &rewriter) const override {
447     // Save the coroutine state: @llvm.coro.save
448     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
449         op, AsyncAPI::tokenType(op->getContext()), operands);
450 
451     return success();
452   }
453 };
454 } // namespace
455 
456 //===----------------------------------------------------------------------===//
457 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
458 //===----------------------------------------------------------------------===//
459 
460 namespace {
461 
462 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
463 /// branch to the appropriate block based on the return code.
464 ///
465 /// Before:
466 ///
467 ///   ^suspended:
468 ///     "opBefore"(...)
469 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
470 ///   ^resume:
471 ///     "op"(...)
472 ///   ^cleanup: ...
473 ///   ^suspend: ...
474 ///
475 /// After:
476 ///
477 ///   ^suspended:
478 ///     "opBefore"(...)
479 ///     %suspend = llmv.intr.coro.suspend ...
480 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
481 ///   ^resume:
482 ///     "op"(...)
483 ///   ^cleanup: ...
484 ///   ^suspend: ...
485 ///
486 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
487 public:
488   using OpConversionPattern::OpConversionPattern;
489 
490   LogicalResult
491   matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
492                   ConversionPatternRewriter &rewriter) const override {
493     auto i8 = rewriter.getIntegerType(8);
494     auto i32 = rewriter.getI32Type();
495     auto loc = op->getLoc();
496 
497     // This is not a final suspension point.
498     auto constFalse = rewriter.create<LLVM::ConstantOp>(
499         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
500 
501     // Suspend a coroutine: @llvm.coro.suspend
502     auto coroState = CoroSuspendOpAdaptor(operands).state();
503     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
504         loc, i8, ValueRange({coroState, constFalse}));
505 
506     // Cast return code to i32.
507 
508     // After a suspension point decide if we should branch into resume, cleanup
509     // or suspend block of the coroutine (see @llvm.coro.suspend return code
510     // documentation).
511     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
512     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
513                                               op.cleanupDest()};
514     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
515         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
516         /*defaultDestination=*/op.suspendDest(),
517         /*defaultOperands=*/ValueRange(),
518         /*caseValues=*/caseValues,
519         /*caseDestinations=*/caseDest,
520         /*caseOperands=*/ArrayRef<ValueRange>(),
521         /*branchWeights=*/ArrayRef<int32_t>());
522 
523     return success();
524   }
525 };
526 } // namespace
527 
528 //===----------------------------------------------------------------------===//
529 // Convert async.runtime.create to the corresponding runtime API call.
530 //
// To allocate storage for the async values we use the getelementptr trick:
532 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
533 //===----------------------------------------------------------------------===//
534 
535 namespace {
536 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
537 public:
538   using OpConversionPattern::OpConversionPattern;
539 
540   LogicalResult
541   matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
542                   ConversionPatternRewriter &rewriter) const override {
543     TypeConverter *converter = getTypeConverter();
544     Type resultType = op->getResultTypes()[0];
545 
546     // Tokens and Groups lowered to function calls without arguments.
547     if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
548       rewriter.replaceOpWithNewOp<CallOp>(
549           op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
550           converter->convertType(resultType));
551       return success();
552     }
553 
554     // To create a value we need to compute the storage requirement.
555     if (auto value = resultType.dyn_cast<ValueType>()) {
556       // Returns the size requirements for the async value storage.
557       auto sizeOf = [&](ValueType valueType) -> Value {
558         auto loc = op->getLoc();
559         auto i32 = rewriter.getI32Type();
560 
561         auto storedType = converter->convertType(valueType.getValueType());
562         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
563 
564         // %Size = getelementptr %T* null, int 1
565         // %SizeI = ptrtoint %T* %Size to i32
566         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
567         auto one = rewriter.create<LLVM::ConstantOp>(
568             loc, i32, rewriter.getI32IntegerAttr(1));
569         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
570                                                 one.getResult());
571         return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep);
572       };
573 
574       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
575                                           sizeOf(value));
576 
577       return success();
578     }
579 
580     return rewriter.notifyMatchFailure(op, "unsupported async type");
581   }
582 };
583 } // namespace
584 
585 //===----------------------------------------------------------------------===//
586 // Convert async.runtime.set_available to the corresponding runtime API call.
587 //===----------------------------------------------------------------------===//
588 
589 namespace {
590 class RuntimeSetAvailableOpLowering
591     : public OpConversionPattern<RuntimeSetAvailableOp> {
592 public:
593   using OpConversionPattern::OpConversionPattern;
594 
595   LogicalResult
596   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
597                   ConversionPatternRewriter &rewriter) const override {
598     StringRef apiFuncName =
599         TypeSwitch<Type, StringRef>(op.operand().getType())
600             .Case<TokenType>([](Type) { return kEmplaceToken; })
601             .Case<ValueType>([](Type) { return kEmplaceValue; });
602 
603     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
604 
605     return success();
606   }
607 };
608 } // namespace
609 
610 //===----------------------------------------------------------------------===//
611 // Convert async.runtime.set_error to the corresponding runtime API call.
612 //===----------------------------------------------------------------------===//
613 
614 namespace {
615 class RuntimeSetErrorOpLowering
616     : public OpConversionPattern<RuntimeSetErrorOp> {
617 public:
618   using OpConversionPattern::OpConversionPattern;
619 
620   LogicalResult
621   matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
622                   ConversionPatternRewriter &rewriter) const override {
623     StringRef apiFuncName =
624         TypeSwitch<Type, StringRef>(op.operand().getType())
625             .Case<TokenType>([](Type) { return kSetTokenError; })
626             .Case<ValueType>([](Type) { return kSetValueError; });
627 
628     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
629 
630     return success();
631   }
632 };
633 } // namespace
634 
635 //===----------------------------------------------------------------------===//
636 // Convert async.runtime.is_error to the corresponding runtime API call.
637 //===----------------------------------------------------------------------===//
638 
639 namespace {
640 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
641 public:
642   using OpConversionPattern::OpConversionPattern;
643 
644   LogicalResult
645   matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
646                   ConversionPatternRewriter &rewriter) const override {
647     StringRef apiFuncName =
648         TypeSwitch<Type, StringRef>(op.operand().getType())
649             .Case<TokenType>([](Type) { return kIsTokenError; })
650             .Case<GroupType>([](Type) { return kIsGroupError; })
651             .Case<ValueType>([](Type) { return kIsValueError; });
652 
653     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
654                                         operands);
655     return success();
656   }
657 };
658 } // namespace
659 
660 //===----------------------------------------------------------------------===//
661 // Convert async.runtime.await to the corresponding runtime API call.
662 //===----------------------------------------------------------------------===//
663 
664 namespace {
665 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
666 public:
667   using OpConversionPattern::OpConversionPattern;
668 
669   LogicalResult
670   matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
671                   ConversionPatternRewriter &rewriter) const override {
672     StringRef apiFuncName =
673         TypeSwitch<Type, StringRef>(op.operand().getType())
674             .Case<TokenType>([](Type) { return kAwaitToken; })
675             .Case<ValueType>([](Type) { return kAwaitValue; })
676             .Case<GroupType>([](Type) { return kAwaitGroup; });
677 
678     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
679     rewriter.eraseOp(op);
680 
681     return success();
682   }
683 };
684 } // namespace
685 
686 //===----------------------------------------------------------------------===//
687 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
688 //===----------------------------------------------------------------------===//
689 
690 namespace {
691 class RuntimeAwaitAndResumeOpLowering
692     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
693 public:
694   using OpConversionPattern::OpConversionPattern;
695 
696   LogicalResult
697   matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
698                   ConversionPatternRewriter &rewriter) const override {
699     StringRef apiFuncName =
700         TypeSwitch<Type, StringRef>(op.operand().getType())
701             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
702             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
703             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
704 
705     Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
706     Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
707 
708     // A pointer to coroutine resume intrinsic wrapper.
709     addResumeFunction(op->getParentOfType<ModuleOp>());
710     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
711     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
712         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
713 
714     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
715                             ValueRange({operand, handle, resumePtr.res()}));
716     rewriter.eraseOp(op);
717 
718     return success();
719   }
720 };
721 } // namespace
722 
723 //===----------------------------------------------------------------------===//
724 // Convert async.runtime.resume to the corresponding runtime API call.
725 //===----------------------------------------------------------------------===//
726 
727 namespace {
728 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
729 public:
730   using OpConversionPattern::OpConversionPattern;
731 
732   LogicalResult
733   matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
734                   ConversionPatternRewriter &rewriter) const override {
735     // A pointer to coroutine resume intrinsic wrapper.
736     addResumeFunction(op->getParentOfType<ModuleOp>());
737     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
738     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
739         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
740 
741     // Call async runtime API to execute a coroutine in the managed thread.
742     auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
743     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
744                                         ValueRange({coroHdl, resumePtr.res()}));
745 
746     return success();
747   }
748 };
749 } // namespace
750 
751 //===----------------------------------------------------------------------===//
752 // Convert async.runtime.store to the corresponding runtime API call.
753 //===----------------------------------------------------------------------===//
754 
755 namespace {
756 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
757 public:
758   using OpConversionPattern::OpConversionPattern;
759 
760   LogicalResult
761   matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
762                   ConversionPatternRewriter &rewriter) const override {
763     Location loc = op->getLoc();
764 
765     // Get a pointer to the async value storage from the runtime.
766     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
767     auto storage = RuntimeStoreOpAdaptor(operands).storage();
768     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
769                                               TypeRange(i8Ptr), storage);
770 
771     // Cast from i8* to the LLVM pointer type.
772     auto valueType = op.value().getType();
773     auto llvmValueType = getTypeConverter()->convertType(valueType);
774     if (!llvmValueType)
775       return rewriter.notifyMatchFailure(
776           op, "failed to convert stored value type to LLVM type");
777 
778     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
779         loc, LLVM::LLVMPointerType::get(llvmValueType),
780         storagePtr.getResult(0));
781 
782     // Store the yielded value into the async value storage.
783     auto value = RuntimeStoreOpAdaptor(operands).value();
784     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
785 
786     // Erase the original runtime store operation.
787     rewriter.eraseOp(op);
788 
789     return success();
790   }
791 };
792 } // namespace
793 
794 //===----------------------------------------------------------------------===//
795 // Convert async.runtime.load to the corresponding runtime API call.
796 //===----------------------------------------------------------------------===//
797 
798 namespace {
799 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
800 public:
801   using OpConversionPattern::OpConversionPattern;
802 
803   LogicalResult
804   matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
805                   ConversionPatternRewriter &rewriter) const override {
806     Location loc = op->getLoc();
807 
808     // Get a pointer to the async value storage from the runtime.
809     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
810     auto storage = RuntimeLoadOpAdaptor(operands).storage();
811     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
812                                               TypeRange(i8Ptr), storage);
813 
814     // Cast from i8* to the LLVM pointer type.
815     auto valueType = op.result().getType();
816     auto llvmValueType = getTypeConverter()->convertType(valueType);
817     if (!llvmValueType)
818       return rewriter.notifyMatchFailure(
819           op, "failed to convert loaded value type to LLVM type");
820 
821     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
822         loc, LLVM::LLVMPointerType::get(llvmValueType),
823         storagePtr.getResult(0));
824 
825     // Load from the casted pointer.
826     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
827 
828     return success();
829   }
830 };
831 } // namespace
832 
833 //===----------------------------------------------------------------------===//
834 // Convert async.runtime.add_to_group to the corresponding runtime API call.
835 //===----------------------------------------------------------------------===//
836 
837 namespace {
838 class RuntimeAddToGroupOpLowering
839     : public OpConversionPattern<RuntimeAddToGroupOp> {
840 public:
841   using OpConversionPattern::OpConversionPattern;
842 
843   LogicalResult
844   matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
845                   ConversionPatternRewriter &rewriter) const override {
846     // Currently we can only add tokens to the group.
847     if (!op.operand().getType().isa<TokenType>())
848       return rewriter.notifyMatchFailure(op, "only token type is supported");
849 
850     // Replace with a runtime API function call.
851     rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
852                                         rewriter.getI64Type(), operands);
853 
854     return success();
855   }
856 };
857 } // namespace
858 
859 //===----------------------------------------------------------------------===//
860 // Async reference counting ops lowering (`async.runtime.add_ref` and
861 // `async.runtime.drop_ref` to the corresponding API calls).
862 //===----------------------------------------------------------------------===//
863 
864 namespace {
865 template <typename RefCountingOp>
866 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
867 public:
868   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
869                                  StringRef apiFunctionName)
870       : OpConversionPattern<RefCountingOp>(converter, ctx),
871         apiFunctionName(apiFunctionName) {}
872 
873   LogicalResult
874   matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
875                   ConversionPatternRewriter &rewriter) const override {
876     auto count =
877         rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
878                                     rewriter.getI32IntegerAttr(op.count()));
879 
880     auto operand = typename RefCountingOp::Adaptor(operands).operand();
881     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
882                                         ValueRange({operand, count}));
883 
884     return success();
885   }
886 
887 private:
888   StringRef apiFunctionName;
889 };
890 
/// Reference counting lowering for `RuntimeAddRefOp`, bound to the runtime API
/// function named by `kAddRef`.
class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
public:
  // Forwards to the shared lowering with the `kAddRef` API function name.
  explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
      : RefCountingOpLowering(converter, ctx, kAddRef) {}
};
896 
/// Reference counting lowering for `RuntimeDropRefOp`, bound to the runtime
/// API function named by `kDropRef`.
class RuntimeDropRefOpLowering
    : public RefCountingOpLowering<RuntimeDropRefOp> {
public:
  // Forwards to the shared lowering with the `kDropRef` API function name.
  explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
      : RefCountingOpLowering(converter, ctx, kDropRef) {}
};
903 } // namespace
904 
905 //===----------------------------------------------------------------------===//
906 // Convert return operations that return async values from async regions.
907 //===----------------------------------------------------------------------===//
908 
909 namespace {
910 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
911 public:
912   using OpConversionPattern::OpConversionPattern;
913 
914   LogicalResult
915   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
916                   ConversionPatternRewriter &rewriter) const override {
917     rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
918     return success();
919   }
920 };
921 } // namespace
922 
923 //===----------------------------------------------------------------------===//
924 
925 namespace {
/// Pass that lowers the async runtime and coroutine operations to async
/// runtime API calls and LLVM coroutine intrinsics; the lowering itself is
/// implemented in `runOnOperation`.
struct ConvertAsyncToLLVMPass
    : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
  // Runs the partial conversion and signals pass failure if it fails.
  void runOnOperation() override;
};
930 } // namespace
931 
932 void ConvertAsyncToLLVMPass::runOnOperation() {
933   ModuleOp module = getOperation();
934   MLIRContext *ctx = module->getContext();
935 
936   // Add declarations for most functions required by the coroutines lowering.
937   // We delay adding the resume function until it's needed because it currently
938   // fails to compile unless '-O0' is specified.
939   addAsyncRuntimeApiDeclarations(module);
940   addCRuntimeDeclarations(module);
941 
942   // Lower async.runtime and async.coro operations to Async Runtime API and
943   // LLVM coroutine intrinsics.
944 
945   // Convert async dialect types and operations to LLVM dialect.
946   AsyncRuntimeTypeConverter converter;
947   RewritePatternSet patterns(ctx);
948 
949   // We use conversion to LLVM type to lower async.runtime load and store
950   // operations.
951   LLVMTypeConverter llvmConverter(ctx);
952   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
953 
954   // Convert async types in function signatures and function calls.
955   populateFuncOpTypeConversionPattern(patterns, converter);
956   populateCallOpTypeConversionPattern(patterns, converter);
957 
958   // Convert return operations inside async.execute regions.
959   patterns.add<ReturnOpOpConversion>(converter, ctx);
960 
961   // Lower async.runtime operations to the async runtime API calls.
962   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
963                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
964                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
965                RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
966                RuntimeDropRefOpLowering>(converter, ctx);
967 
968   // Lower async.runtime operations that rely on LLVM type converter to convert
969   // from async value payload type to the LLVM type.
970   patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
971                RuntimeLoadOpLowering>(llvmConverter, ctx);
972 
973   // Lower async coroutine operations to LLVM coroutine intrinsics.
974   patterns
975       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
976            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
977           converter, ctx);
978 
979   ConversionTarget target(*ctx);
980   target.addLegalOp<ConstantOp>();
981   target.addLegalDialect<LLVM::LLVMDialect>();
982 
983   // All operations from Async dialect must be lowered to the runtime API and
984   // LLVM intrinsics calls.
985   target.addIllegalDialect<AsyncDialect>();
986 
987   // Add dynamic legality constraints to apply conversions defined above.
988   target.addDynamicallyLegalOp<FuncOp>(
989       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
990   target.addDynamicallyLegalOp<ReturnOp>(
991       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
992   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
993     return converter.isSignatureLegal(op.getCalleeType());
994   });
995 
996   if (failed(applyPartialConversion(module, target, std::move(patterns))))
997     signalPassFailure();
998 }
999 
1000 //===----------------------------------------------------------------------===//
1001 // Patterns for structural type conversions for the Async dialect operations.
1002 //===----------------------------------------------------------------------===//
1003 
1004 namespace {
1005 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1006 public:
1007   using OpConversionPattern::OpConversionPattern;
1008   LogicalResult
1009   matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
1010                   ConversionPatternRewriter &rewriter) const override {
1011     ExecuteOp newOp =
1012         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1013     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1014                                 newOp.getRegion().end());
1015 
1016     // Set operands and update block argument and result types.
1017     newOp->setOperands(operands);
1018     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1019       return failure();
1020     for (auto result : newOp.getResults())
1021       result.setType(typeConverter->convertType(result.getType()));
1022 
1023     rewriter.replaceOp(op, newOp.getResults());
1024     return success();
1025   }
1026 };
1027 
1028 // Dummy pattern to trigger the appropriate type conversion / materialization.
1029 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1030 public:
1031   using OpConversionPattern::OpConversionPattern;
1032   LogicalResult
1033   matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
1034                   ConversionPatternRewriter &rewriter) const override {
1035     rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
1036     return success();
1037   }
1038 };
1039 
1040 // Dummy pattern to trigger the appropriate type conversion / materialization.
1041 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1042 public:
1043   using OpConversionPattern::OpConversionPattern;
1044   LogicalResult
1045   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
1046                   ConversionPatternRewriter &rewriter) const override {
1047     rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
1048     return success();
1049   }
1050 };
1051 } // namespace
1052 
/// Creates a new instance of the pass that converts the Async dialect to the
/// LLVM dialect. The caller owns the returned pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
  return std::make_unique<ConvertAsyncToLLVMPass>();
}
1056 
1057 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1058     TypeConverter &typeConverter, RewritePatternSet &patterns,
1059     ConversionTarget &target) {
1060   typeConverter.addConversion([&](TokenType type) { return type; });
1061   typeConverter.addConversion([&](ValueType type) {
1062     Type converted = typeConverter.convertType(type.getValueType());
1063     return converted ? ValueType::get(converted) : converted;
1064   });
1065 
1066   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1067       typeConverter, patterns.getContext());
1068 
1069   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1070       [&](Operation *op) { return typeConverter.isLegal(op); });
1071 }
1072