1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 #define DEBUG_TYPE "convert-async-to-llvm"
28 
29 using namespace mlir;
30 using namespace mlir::async;
31 
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35 
// Symbol names of the Async Runtime C API functions, grouped by the family
// their name belongs to.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Completion ("emplace") and value storage access.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";

// Error state.
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Awaiting.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Groups.
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Execution, optionally after awaiting.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
62 
63 namespace {
64 /// Async Runtime API function types.
65 ///
66 /// Because we can't create API function signature for type parametrized
67 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
68 /// lowering all async data types become opaque pointers at runtime.
69 struct AsyncAPI {
70   // All async types are lowered to opaque i8* LLVM pointers at runtime.
71   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
72     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
73   }
74 
75   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
76     return LLVM::LLVMTokenType::get(ctx);
77   }
78 
79   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
80     auto ref = opaquePointerType(ctx);
81     auto count = IntegerType::get(ctx, 64);
82     return FunctionType::get(ctx, {ref, count}, {});
83   }
84 
85   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
86     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
87   }
88 
89   static FunctionType createValueFunctionType(MLIRContext *ctx) {
90     auto i64 = IntegerType::get(ctx, 64);
91     auto value = opaquePointerType(ctx);
92     return FunctionType::get(ctx, {i64}, {value});
93   }
94 
95   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
96     auto i64 = IntegerType::get(ctx, 64);
97     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
98   }
99 
100   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
101     auto value = opaquePointerType(ctx);
102     auto storage = opaquePointerType(ctx);
103     return FunctionType::get(ctx, {value}, {storage});
104   }
105 
106   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
107     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
108   }
109 
110   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
111     auto value = opaquePointerType(ctx);
112     return FunctionType::get(ctx, {value}, {});
113   }
114 
115   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
116     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
117   }
118 
119   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
120     auto value = opaquePointerType(ctx);
121     return FunctionType::get(ctx, {value}, {});
122   }
123 
124   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
125     auto i1 = IntegerType::get(ctx, 1);
126     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
127   }
128 
129   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
130     auto value = opaquePointerType(ctx);
131     auto i1 = IntegerType::get(ctx, 1);
132     return FunctionType::get(ctx, {value}, {i1});
133   }
134 
135   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
136     auto i1 = IntegerType::get(ctx, 1);
137     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
138   }
139 
140   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
141     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
142   }
143 
144   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
145     auto value = opaquePointerType(ctx);
146     return FunctionType::get(ctx, {value}, {});
147   }
148 
149   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
150     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
151   }
152 
153   static FunctionType executeFunctionType(MLIRContext *ctx) {
154     auto hdl = opaquePointerType(ctx);
155     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
156     return FunctionType::get(ctx, {hdl, resume}, {});
157   }
158 
159   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
160     auto i64 = IntegerType::get(ctx, 64);
161     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
162                              {i64});
163   }
164 
165   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
166     auto hdl = opaquePointerType(ctx);
167     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
168     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
169   }
170 
171   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
172     auto value = opaquePointerType(ctx);
173     auto hdl = opaquePointerType(ctx);
174     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
175     return FunctionType::get(ctx, {value, hdl, resume}, {});
176   }
177 
178   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
179     auto hdl = opaquePointerType(ctx);
180     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
181     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
182   }
183 
184   // Auxiliary coroutine resume intrinsic wrapper.
185   static Type resumeFunctionType(MLIRContext *ctx) {
186     auto voidTy = LLVM::LLVMVoidType::get(ctx);
187     auto i8Ptr = opaquePointerType(ctx);
188     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
189   }
190 };
191 } // namespace
192 
193 /// Adds Async Runtime C API declarations to the module.
194 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
195   auto builder =
196       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
197 
198   auto addFuncDecl = [&](StringRef name, FunctionType type) {
199     if (module.lookupSymbol(name))
200       return;
201     builder.create<FuncOp>(name, type).setPrivate();
202   };
203 
204   MLIRContext *ctx = module.getContext();
205   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
206   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
207   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
208   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
209   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
210   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
211   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
212   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
213   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
214   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
215   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
216   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
217   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
218   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
219   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
220   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
221   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
222   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
223   addFuncDecl(kAwaitTokenAndExecute,
224               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
225   addFuncDecl(kAwaitValueAndExecute,
226               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
227   addFuncDecl(kAwaitAllAndExecute,
228               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // Coroutine resume function wrapper.
233 //===----------------------------------------------------------------------===//
234 
235 static constexpr const char *kResume = "__resume";
236 
237 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
238 /// intrinsics. We need this function to be able to pass it to the async
239 /// runtime execute API.
240 static void addResumeFunction(ModuleOp module) {
241   if (module.lookupSymbol(kResume))
242     return;
243 
244   MLIRContext *ctx = module.getContext();
245   auto loc = module.getLoc();
246   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
247 
248   auto voidTy = LLVM::LLVMVoidType::get(ctx);
249   auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
250 
251   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
252       kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
253   resumeOp.setPrivate();
254 
255   auto *block = resumeOp.addEntryBlock();
256   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
257 
258   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
259   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // Convert Async dialect types to LLVM types.
264 //===----------------------------------------------------------------------===//
265 
266 namespace {
267 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
268 /// their runtime type (opaque pointers) and does not convert any other types.
269 class AsyncRuntimeTypeConverter : public TypeConverter {
270 public:
271   AsyncRuntimeTypeConverter() {
272     addConversion([](Type type) { return type; });
273     addConversion(convertAsyncTypes);
274   }
275 
276   static Optional<Type> convertAsyncTypes(Type type) {
277     if (type.isa<TokenType, GroupType, ValueType>())
278       return AsyncAPI::opaquePointerType(type.getContext());
279 
280     if (type.isa<CoroIdType, CoroStateType>())
281       return AsyncAPI::tokenType(type.getContext());
282     if (type.isa<CoroHandleType>())
283       return AsyncAPI::opaquePointerType(type.getContext());
284 
285     return llvm::None;
286   }
287 };
288 } // namespace
289 
290 //===----------------------------------------------------------------------===//
291 // Convert async.coro.id to @llvm.coro.id intrinsic.
292 //===----------------------------------------------------------------------===//
293 
294 namespace {
295 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
296 public:
297   using OpConversionPattern::OpConversionPattern;
298 
299   LogicalResult
300   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
301                   ConversionPatternRewriter &rewriter) const override {
302     auto token = AsyncAPI::tokenType(op->getContext());
303     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
304     auto loc = op->getLoc();
305 
306     // Constants for initializing coroutine frame.
307     auto constZero = rewriter.create<LLVM::ConstantOp>(
308         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
309     auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
310 
311     // Get coroutine id: @llvm.coro.id.
312     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
313         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
314 
315     return success();
316   }
317 };
318 } // namespace
319 
320 //===----------------------------------------------------------------------===//
321 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
322 //===----------------------------------------------------------------------===//
323 
324 namespace {
325 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
326 public:
327   using OpConversionPattern::OpConversionPattern;
328 
329   LogicalResult
330   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
331                   ConversionPatternRewriter &rewriter) const override {
332     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
333     auto loc = op->getLoc();
334 
335     // Get coroutine frame size: @llvm.coro.size.i64.
336     auto coroSize =
337         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
338     // The coroutine lowering doesn't properly account for alignment of the
339     // frame, so align everything to 64 bytes which ought to be enough for
340     // everyone. https://llvm.org/PR53148
341     auto coroAlign = rewriter.create<LLVM::ConstantOp>(
342         op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(64));
343 
344     // Allocate memory for the coroutine frame.
345     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
346         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
347     auto coroAlloc = rewriter.create<LLVM::CallOp>(
348         loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
349         ValueRange{coroAlign, coroSize.getResult()});
350 
351     // Begin a coroutine: @llvm.coro.begin.
352     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
353     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
354         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
355 
356     return success();
357   }
358 };
359 } // namespace
360 
361 //===----------------------------------------------------------------------===//
362 // Convert async.coro.free to @llvm.coro.free intrinsic.
363 //===----------------------------------------------------------------------===//
364 
365 namespace {
366 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
367 public:
368   using OpConversionPattern::OpConversionPattern;
369 
370   LogicalResult
371   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
372                   ConversionPatternRewriter &rewriter) const override {
373     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
374     auto loc = op->getLoc();
375 
376     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
377     auto coroMem =
378         rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
379 
380     // Free the memory.
381     auto freeFuncOp =
382         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
383     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
384                                               SymbolRefAttr::get(freeFuncOp),
385                                               ValueRange(coroMem.getResult()));
386 
387     return success();
388   }
389 };
390 } // namespace
391 
392 //===----------------------------------------------------------------------===//
393 // Convert async.coro.end to @llvm.coro.end intrinsic.
394 //===----------------------------------------------------------------------===//
395 
396 namespace {
397 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
398 public:
399   using OpConversionPattern::OpConversionPattern;
400 
401   LogicalResult
402   matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
403                   ConversionPatternRewriter &rewriter) const override {
404     // We are not in the block that is part of the unwind sequence.
405     auto constFalse = rewriter.create<LLVM::ConstantOp>(
406         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
407 
408     // Mark the end of a coroutine: @llvm.coro.end.
409     auto coroHdl = adaptor.handle();
410     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
411                                      ValueRange({coroHdl, constFalse}));
412     rewriter.eraseOp(op);
413 
414     return success();
415   }
416 };
417 } // namespace
418 
419 //===----------------------------------------------------------------------===//
420 // Convert async.coro.save to @llvm.coro.save intrinsic.
421 //===----------------------------------------------------------------------===//
422 
423 namespace {
424 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
425 public:
426   using OpConversionPattern::OpConversionPattern;
427 
428   LogicalResult
429   matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
430                   ConversionPatternRewriter &rewriter) const override {
431     // Save the coroutine state: @llvm.coro.save
432     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
433         op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
434 
435     return success();
436   }
437 };
438 } // namespace
439 
440 //===----------------------------------------------------------------------===//
441 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
442 //===----------------------------------------------------------------------===//
443 
444 namespace {
445 
446 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
447 /// branch to the appropriate block based on the return code.
448 ///
449 /// Before:
450 ///
451 ///   ^suspended:
452 ///     "opBefore"(...)
453 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
454 ///   ^resume:
455 ///     "op"(...)
456 ///   ^cleanup: ...
457 ///   ^suspend: ...
458 ///
459 /// After:
460 ///
461 ///   ^suspended:
462 ///     "opBefore"(...)
463 ///     %suspend = llmv.intr.coro.suspend ...
464 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
465 ///   ^resume:
466 ///     "op"(...)
467 ///   ^cleanup: ...
468 ///   ^suspend: ...
469 ///
470 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
471 public:
472   using OpConversionPattern::OpConversionPattern;
473 
474   LogicalResult
475   matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
476                   ConversionPatternRewriter &rewriter) const override {
477     auto i8 = rewriter.getIntegerType(8);
478     auto i32 = rewriter.getI32Type();
479     auto loc = op->getLoc();
480 
481     // This is not a final suspension point.
482     auto constFalse = rewriter.create<LLVM::ConstantOp>(
483         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
484 
485     // Suspend a coroutine: @llvm.coro.suspend
486     auto coroState = adaptor.state();
487     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
488         loc, i8, ValueRange({coroState, constFalse}));
489 
490     // Cast return code to i32.
491 
492     // After a suspension point decide if we should branch into resume, cleanup
493     // or suspend block of the coroutine (see @llvm.coro.suspend return code
494     // documentation).
495     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
496     llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
497                                               op.cleanupDest()};
498     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
499         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
500         /*defaultDestination=*/op.suspendDest(),
501         /*defaultOperands=*/ValueRange(),
502         /*caseValues=*/caseValues,
503         /*caseDestinations=*/caseDest,
504         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
505         /*branchWeights=*/ArrayRef<int32_t>());
506 
507     return success();
508   }
509 };
510 } // namespace
511 
512 //===----------------------------------------------------------------------===//
513 // Convert async.runtime.create to the corresponding runtime API call.
514 //
515 // To allocate storage for the async values we use getelementptr trick:
516 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
517 //===----------------------------------------------------------------------===//
518 
519 namespace {
520 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
521 public:
522   using OpConversionPattern::OpConversionPattern;
523 
524   LogicalResult
525   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
526                   ConversionPatternRewriter &rewriter) const override {
527     TypeConverter *converter = getTypeConverter();
528     Type resultType = op->getResultTypes()[0];
529 
530     // Tokens creation maps to a simple function call.
531     if (resultType.isa<TokenType>()) {
532       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
533                                           converter->convertType(resultType));
534       return success();
535     }
536 
537     // To create a value we need to compute the storage requirement.
538     if (auto value = resultType.dyn_cast<ValueType>()) {
539       // Returns the size requirements for the async value storage.
540       auto sizeOf = [&](ValueType valueType) -> Value {
541         auto loc = op->getLoc();
542         auto i64 = rewriter.getI64Type();
543 
544         auto storedType = converter->convertType(valueType.getValueType());
545         auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
546 
547         // %Size = getelementptr %T* null, int 1
548         // %SizeI = ptrtoint %T* %Size to i64
549         auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
550         auto one = rewriter.create<LLVM::ConstantOp>(
551             loc, i64, rewriter.getI64IntegerAttr(1));
552         auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
553                                                 one.getResult());
554         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
555       };
556 
557       rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType,
558                                           sizeOf(value));
559 
560       return success();
561     }
562 
563     return rewriter.notifyMatchFailure(op, "unsupported async type");
564   }
565 };
566 } // namespace
567 
568 //===----------------------------------------------------------------------===//
569 // Convert async.runtime.create_group to the corresponding runtime API call.
570 //===----------------------------------------------------------------------===//
571 
572 namespace {
573 class RuntimeCreateGroupOpLowering
574     : public OpConversionPattern<RuntimeCreateGroupOp> {
575 public:
576   using OpConversionPattern::OpConversionPattern;
577 
578   LogicalResult
579   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
580                   ConversionPatternRewriter &rewriter) const override {
581     TypeConverter *converter = getTypeConverter();
582     Type resultType = op.getResult().getType();
583 
584     rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup,
585                                         converter->convertType(resultType),
586                                         adaptor.getOperands());
587     return success();
588   }
589 };
590 } // namespace
591 
592 //===----------------------------------------------------------------------===//
593 // Convert async.runtime.set_available to the corresponding runtime API call.
594 //===----------------------------------------------------------------------===//
595 
596 namespace {
597 class RuntimeSetAvailableOpLowering
598     : public OpConversionPattern<RuntimeSetAvailableOp> {
599 public:
600   using OpConversionPattern::OpConversionPattern;
601 
602   LogicalResult
603   matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
604                   ConversionPatternRewriter &rewriter) const override {
605     StringRef apiFuncName =
606         TypeSwitch<Type, StringRef>(op.operand().getType())
607             .Case<TokenType>([](Type) { return kEmplaceToken; })
608             .Case<ValueType>([](Type) { return kEmplaceValue; });
609 
610     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
611                                         adaptor.getOperands());
612 
613     return success();
614   }
615 };
616 } // namespace
617 
618 //===----------------------------------------------------------------------===//
619 // Convert async.runtime.set_error to the corresponding runtime API call.
620 //===----------------------------------------------------------------------===//
621 
622 namespace {
623 class RuntimeSetErrorOpLowering
624     : public OpConversionPattern<RuntimeSetErrorOp> {
625 public:
626   using OpConversionPattern::OpConversionPattern;
627 
628   LogicalResult
629   matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
630                   ConversionPatternRewriter &rewriter) const override {
631     StringRef apiFuncName =
632         TypeSwitch<Type, StringRef>(op.operand().getType())
633             .Case<TokenType>([](Type) { return kSetTokenError; })
634             .Case<ValueType>([](Type) { return kSetValueError; });
635 
636     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
637                                         adaptor.getOperands());
638 
639     return success();
640   }
641 };
642 } // namespace
643 
644 //===----------------------------------------------------------------------===//
645 // Convert async.runtime.is_error to the corresponding runtime API call.
646 //===----------------------------------------------------------------------===//
647 
648 namespace {
649 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
650 public:
651   using OpConversionPattern::OpConversionPattern;
652 
653   LogicalResult
654   matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
655                   ConversionPatternRewriter &rewriter) const override {
656     StringRef apiFuncName =
657         TypeSwitch<Type, StringRef>(op.operand().getType())
658             .Case<TokenType>([](Type) { return kIsTokenError; })
659             .Case<GroupType>([](Type) { return kIsGroupError; })
660             .Case<ValueType>([](Type) { return kIsValueError; });
661 
662     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
663                                         adaptor.getOperands());
664     return success();
665   }
666 };
667 } // namespace
668 
669 //===----------------------------------------------------------------------===//
670 // Convert async.runtime.await to the corresponding runtime API call.
671 //===----------------------------------------------------------------------===//
672 
673 namespace {
674 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
675 public:
676   using OpConversionPattern::OpConversionPattern;
677 
678   LogicalResult
679   matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
680                   ConversionPatternRewriter &rewriter) const override {
681     StringRef apiFuncName =
682         TypeSwitch<Type, StringRef>(op.operand().getType())
683             .Case<TokenType>([](Type) { return kAwaitToken; })
684             .Case<ValueType>([](Type) { return kAwaitValue; })
685             .Case<GroupType>([](Type) { return kAwaitGroup; });
686 
687     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
688                             adaptor.getOperands());
689     rewriter.eraseOp(op);
690 
691     return success();
692   }
693 };
694 } // namespace
695 
696 //===----------------------------------------------------------------------===//
697 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
698 //===----------------------------------------------------------------------===//
699 
700 namespace {
701 class RuntimeAwaitAndResumeOpLowering
702     : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
703 public:
704   using OpConversionPattern::OpConversionPattern;
705 
706   LogicalResult
707   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
708                   ConversionPatternRewriter &rewriter) const override {
709     StringRef apiFuncName =
710         TypeSwitch<Type, StringRef>(op.operand().getType())
711             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
712             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
713             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
714 
715     Value operand = adaptor.operand();
716     Value handle = adaptor.handle();
717 
718     // A pointer to coroutine resume intrinsic wrapper.
719     addResumeFunction(op->getParentOfType<ModuleOp>());
720     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
721     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
722         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
723 
724     rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
725                             ValueRange({operand, handle, resumePtr.getRes()}));
726     rewriter.eraseOp(op);
727 
728     return success();
729   }
730 };
731 } // namespace
732 
733 //===----------------------------------------------------------------------===//
734 // Convert async.runtime.resume to the corresponding runtime API call.
735 //===----------------------------------------------------------------------===//
736 
737 namespace {
738 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
739 public:
740   using OpConversionPattern::OpConversionPattern;
741 
742   LogicalResult
743   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
744                   ConversionPatternRewriter &rewriter) const override {
745     // A pointer to coroutine resume intrinsic wrapper.
746     addResumeFunction(op->getParentOfType<ModuleOp>());
747     auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
748     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
749         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
750 
751     // Call async runtime API to execute a coroutine in the managed thread.
752     auto coroHdl = adaptor.handle();
753     rewriter.replaceOpWithNewOp<CallOp>(
754         op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
755 
756     return success();
757   }
758 };
759 } // namespace
760 
761 //===----------------------------------------------------------------------===//
762 // Convert async.runtime.store to the corresponding runtime API call.
763 //===----------------------------------------------------------------------===//
764 
765 namespace {
766 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
767 public:
768   using OpConversionPattern::OpConversionPattern;
769 
770   LogicalResult
771   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
772                   ConversionPatternRewriter &rewriter) const override {
773     Location loc = op->getLoc();
774 
775     // Get a pointer to the async value storage from the runtime.
776     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
777     auto storage = adaptor.storage();
778     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
779                                               TypeRange(i8Ptr), storage);
780 
781     // Cast from i8* to the LLVM pointer type.
782     auto valueType = op.value().getType();
783     auto llvmValueType = getTypeConverter()->convertType(valueType);
784     if (!llvmValueType)
785       return rewriter.notifyMatchFailure(
786           op, "failed to convert stored value type to LLVM type");
787 
788     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
789         loc, LLVM::LLVMPointerType::get(llvmValueType),
790         storagePtr.getResult(0));
791 
792     // Store the yielded value into the async value storage.
793     auto value = adaptor.value();
794     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
795 
796     // Erase the original runtime store operation.
797     rewriter.eraseOp(op);
798 
799     return success();
800   }
801 };
802 } // namespace
803 
804 //===----------------------------------------------------------------------===//
805 // Convert async.runtime.load to the corresponding runtime API call.
806 //===----------------------------------------------------------------------===//
807 
808 namespace {
809 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
810 public:
811   using OpConversionPattern::OpConversionPattern;
812 
813   LogicalResult
814   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
815                   ConversionPatternRewriter &rewriter) const override {
816     Location loc = op->getLoc();
817 
818     // Get a pointer to the async value storage from the runtime.
819     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
820     auto storage = adaptor.storage();
821     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
822                                               TypeRange(i8Ptr), storage);
823 
824     // Cast from i8* to the LLVM pointer type.
825     auto valueType = op.result().getType();
826     auto llvmValueType = getTypeConverter()->convertType(valueType);
827     if (!llvmValueType)
828       return rewriter.notifyMatchFailure(
829           op, "failed to convert loaded value type to LLVM type");
830 
831     auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
832         loc, LLVM::LLVMPointerType::get(llvmValueType),
833         storagePtr.getResult(0));
834 
835     // Load from the casted pointer.
836     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
837 
838     return success();
839   }
840 };
841 } // namespace
842 
843 //===----------------------------------------------------------------------===//
844 // Convert async.runtime.add_to_group to the corresponding runtime API call.
845 //===----------------------------------------------------------------------===//
846 
847 namespace {
848 class RuntimeAddToGroupOpLowering
849     : public OpConversionPattern<RuntimeAddToGroupOp> {
850 public:
851   using OpConversionPattern::OpConversionPattern;
852 
853   LogicalResult
854   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
855                   ConversionPatternRewriter &rewriter) const override {
856     // Currently we can only add tokens to the group.
857     if (!op.operand().getType().isa<TokenType>())
858       return rewriter.notifyMatchFailure(op, "only token type is supported");
859 
860     // Replace with a runtime API function call.
861     rewriter.replaceOpWithNewOp<CallOp>(
862         op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
863 
864     return success();
865   }
866 };
867 } // namespace
868 
869 //===----------------------------------------------------------------------===//
870 // Async reference counting ops lowering (`async.runtime.add_ref` and
871 // `async.runtime.drop_ref` to the corresponding API calls).
872 //===----------------------------------------------------------------------===//
873 
874 namespace {
875 template <typename RefCountingOp>
876 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
877 public:
878   explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
879                                  StringRef apiFunctionName)
880       : OpConversionPattern<RefCountingOp>(converter, ctx),
881         apiFunctionName(apiFunctionName) {}
882 
883   LogicalResult
884   matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
885                   ConversionPatternRewriter &rewriter) const override {
886     auto count = rewriter.create<arith::ConstantOp>(
887         op->getLoc(), rewriter.getI64Type(),
888         rewriter.getI64IntegerAttr(op.count()));
889 
890     auto operand = adaptor.operand();
891     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
892                                         ValueRange({operand, count}));
893 
894     return success();
895   }
896 
897 private:
898   StringRef apiFunctionName;
899 };
900 
901 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
902 public:
903   explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
904       : RefCountingOpLowering(converter, ctx, kAddRef) {}
905 };
906 
907 class RuntimeDropRefOpLowering
908     : public RefCountingOpLowering<RuntimeDropRefOp> {
909 public:
910   explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
911       : RefCountingOpLowering(converter, ctx, kDropRef) {}
912 };
913 } // namespace
914 
915 //===----------------------------------------------------------------------===//
916 // Convert return operations that return async values from async regions.
917 //===----------------------------------------------------------------------===//
918 
919 namespace {
920 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> {
921 public:
922   using OpConversionPattern::OpConversionPattern;
923 
924   LogicalResult
925   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
926                   ConversionPatternRewriter &rewriter) const override {
927     rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
928     return success();
929   }
930 };
931 } // namespace
932 
933 //===----------------------------------------------------------------------===//
934 
935 namespace {
/// Pass that converts Async dialect operations to Async Runtime API calls and
/// LLVM dialect operations. The conversion itself is implemented in
/// `runOnOperation`, which is defined out of line.
// NOTE(review): `ConvertAsyncToLLVMBase` is presumably the generated pass base
// class pulled in through PassDetail.h - confirm.
struct ConvertAsyncToLLVMPass
    : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
  void runOnOperation() override;
};
940 } // namespace
941 
942 void ConvertAsyncToLLVMPass::runOnOperation() {
943   ModuleOp module = getOperation();
944   MLIRContext *ctx = module->getContext();
945 
946   // Add declarations for most functions required by the coroutines lowering.
947   // We delay adding the resume function until it's needed because it currently
948   // fails to compile unless '-O0' is specified.
949   addAsyncRuntimeApiDeclarations(module);
950 
951   // Lower async.runtime and async.coro operations to Async Runtime API and
952   // LLVM coroutine intrinsics.
953 
954   // Convert async dialect types and operations to LLVM dialect.
955   AsyncRuntimeTypeConverter converter;
956   RewritePatternSet patterns(ctx);
957 
958   // We use conversion to LLVM type to lower async.runtime load and store
959   // operations.
960   LLVMTypeConverter llvmConverter(ctx);
961   llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
962 
963   // Convert async types in function signatures and function calls.
964   populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, converter);
965   populateCallOpTypeConversionPattern(patterns, converter);
966 
967   // Convert return operations inside async.execute regions.
968   patterns.add<ReturnOpOpConversion>(converter, ctx);
969 
970   // Lower async.runtime operations to the async runtime API calls.
971   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
972                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
973                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
974                RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
975                RuntimeDropRefOpLowering>(converter, ctx);
976 
977   // Lower async.runtime operations that rely on LLVM type converter to convert
978   // from async value payload type to the LLVM type.
979   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
980                RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
981                                                               ctx);
982 
983   // Lower async coroutine operations to LLVM coroutine intrinsics.
984   patterns
985       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
986            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
987           converter, ctx);
988 
989   ConversionTarget target(*ctx);
990   target
991       .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>();
992   target.addLegalDialect<LLVM::LLVMDialect>();
993 
994   // All operations from Async dialect must be lowered to the runtime API and
995   // LLVM intrinsics calls.
996   target.addIllegalDialect<AsyncDialect>();
997 
998   // Add dynamic legality constraints to apply conversions defined above.
999   target.addDynamicallyLegalOp<FuncOp>(
1000       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1001   target.addDynamicallyLegalOp<ReturnOp>(
1002       [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
1003   target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
1004     return converter.isSignatureLegal(op.getCalleeType());
1005   });
1006 
1007   if (failed(applyPartialConversion(module, target, std::move(patterns))))
1008     signalPassFailure();
1009 }
1010 
1011 //===----------------------------------------------------------------------===//
1012 // Patterns for structural type conversions for the Async dialect operations.
1013 //===----------------------------------------------------------------------===//
1014 
1015 namespace {
1016 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1017 public:
1018   using OpConversionPattern::OpConversionPattern;
1019   LogicalResult
1020   matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1021                   ConversionPatternRewriter &rewriter) const override {
1022     ExecuteOp newOp =
1023         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1024     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1025                                 newOp.getRegion().end());
1026 
1027     // Set operands and update block argument and result types.
1028     newOp->setOperands(adaptor.getOperands());
1029     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1030       return failure();
1031     for (auto result : newOp.getResults())
1032       result.setType(typeConverter->convertType(result.getType()));
1033 
1034     rewriter.replaceOp(op, newOp.getResults());
1035     return success();
1036   }
1037 };
1038 
1039 // Dummy pattern to trigger the appropriate type conversion / materialization.
1040 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1041 public:
1042   using OpConversionPattern::OpConversionPattern;
1043   LogicalResult
1044   matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1045                   ConversionPatternRewriter &rewriter) const override {
1046     rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1047     return success();
1048   }
1049 };
1050 
1051 // Dummy pattern to trigger the appropriate type conversion / materialization.
1052 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1053 public:
1054   using OpConversionPattern::OpConversionPattern;
1055   LogicalResult
1056   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1057                   ConversionPatternRewriter &rewriter) const override {
1058     rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1059     return success();
1060   }
1061 };
1062 } // namespace
1063 
/// Creates a new instance of the pass that converts the Async dialect to the
/// LLVM dialect. The returned pass operates on `ModuleOp`.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
  return std::make_unique<ConvertAsyncToLLVMPass>();
}
1067 
1068 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1069     TypeConverter &typeConverter, RewritePatternSet &patterns,
1070     ConversionTarget &target) {
1071   typeConverter.addConversion([&](TokenType type) { return type; });
1072   typeConverter.addConversion([&](ValueType type) {
1073     Type converted = typeConverter.convertType(type.getValueType());
1074     return converted ? ValueType::get(converted) : converted;
1075   });
1076 
1077   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1078       typeConverter, patterns.getContext());
1079 
1080   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1081       [&](Operation *op) { return typeConverter.isLegal(op); });
1082 }
1083