1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 27 #define DEBUG_TYPE "convert-async-to-llvm" 28 29 using namespace mlir; 30 using namespace mlir::async; 31 32 //===----------------------------------------------------------------------===// 33 // Async Runtime C API declaration. 34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 37 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 38 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 39 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 40 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 41 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 42 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 43 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 44 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 45 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 46 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 47 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 48 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 49 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 50 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 51 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 52 static constexpr const char *kGetValueStorage = 53 "mlirAsyncRuntimeGetValueStorage"; 54 static constexpr const char *kAddTokenToGroup = 55 "mlirAsyncRuntimeAddTokenToGroup"; 56 static constexpr const char *kAwaitTokenAndExecute = 57 "mlirAsyncRuntimeAwaitTokenAndExecute"; 58 static constexpr const char *kAwaitValueAndExecute = 59 "mlirAsyncRuntimeAwaitValueAndExecute"; 60 static constexpr const char *kAwaitAllAndExecute = 61 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 62 63 namespace { 64 /// Async Runtime API function types. 65 /// 66 /// Because we can't create API function signature for type parametrized 67 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 68 /// lowering all async data types become opaque pointers at runtime. 69 struct AsyncAPI { 70 // All async types are lowered to opaque i8* LLVM pointers at runtime. 71 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 72 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 73 } 74 75 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 76 return LLVM::LLVMTokenType::get(ctx); 77 } 78 79 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 80 auto ref = opaquePointerType(ctx); 81 auto count = IntegerType::get(ctx, 64); 82 return FunctionType::get(ctx, {ref, count}, {}); 83 } 84 85 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 86 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 87 } 88 89 static FunctionType createValueFunctionType(MLIRContext *ctx) { 90 auto i64 = IntegerType::get(ctx, 64); 91 auto value = opaquePointerType(ctx); 92 return FunctionType::get(ctx, {i64}, {value}); 93 } 94 95 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 96 auto i64 = IntegerType::get(ctx, 64); 97 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 98 } 99 100 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 101 auto value = opaquePointerType(ctx); 102 auto storage = opaquePointerType(ctx); 103 return FunctionType::get(ctx, {value}, {storage}); 104 } 105 106 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 107 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 108 } 109 110 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 111 auto value = opaquePointerType(ctx); 112 return FunctionType::get(ctx, {value}, {}); 113 } 114 115 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 116 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 117 } 118 119 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 120 auto value = opaquePointerType(ctx); 121 return FunctionType::get(ctx, {value}, {}); 122 } 123 124 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 125 auto i1 = IntegerType::get(ctx, 1); 126 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 127 } 128 129 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 130 auto value = opaquePointerType(ctx); 131 auto i1 = IntegerType::get(ctx, 1); 132 return FunctionType::get(ctx, {value}, {i1}); 133 } 134 135 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 136 auto i1 = IntegerType::get(ctx, 1); 137 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 138 } 139 140 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 141 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 142 } 143 144 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 145 auto value = opaquePointerType(ctx); 146 return FunctionType::get(ctx, {value}, {}); 147 } 148 149 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 150 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 151 } 152 153 static FunctionType executeFunctionType(MLIRContext *ctx) { 154 auto hdl = opaquePointerType(ctx); 155 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 156 return FunctionType::get(ctx, {hdl, resume}, {}); 157 } 158 159 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 160 auto i64 = IntegerType::get(ctx, 64); 161 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 162 {i64}); 163 } 164 165 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 166 auto hdl = opaquePointerType(ctx); 167 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 168 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 169 } 170 171 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 172 auto value = opaquePointerType(ctx); 173 auto hdl = opaquePointerType(ctx); 174 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 175 return FunctionType::get(ctx, {value, hdl, resume}, {}); 176 } 177 178 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 179 auto hdl = opaquePointerType(ctx); 180 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 181 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 182 } 183 184 // Auxiliary coroutine resume intrinsic wrapper. 185 static Type resumeFunctionType(MLIRContext *ctx) { 186 auto voidTy = LLVM::LLVMVoidType::get(ctx); 187 auto i8Ptr = opaquePointerType(ctx); 188 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 189 } 190 }; 191 } // namespace 192 193 /// Adds Async Runtime C API declarations to the module. 194 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 195 auto builder = 196 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 197 198 auto addFuncDecl = [&](StringRef name, FunctionType type) { 199 if (module.lookupSymbol(name)) 200 return; 201 builder.create<FuncOp>(name, type).setPrivate(); 202 }; 203 204 MLIRContext *ctx = module.getContext(); 205 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 206 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 207 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 208 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 209 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 210 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 211 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 212 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 213 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 214 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 215 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 216 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 217 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 218 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 219 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 220 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 221 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 222 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 223 addFuncDecl(kAwaitTokenAndExecute, 224 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 225 addFuncDecl(kAwaitValueAndExecute, 226 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 227 addFuncDecl(kAwaitAllAndExecute, 228 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 229 } 230 231 //===----------------------------------------------------------------------===// 232 // Coroutine resume function wrapper. 233 //===----------------------------------------------------------------------===// 234 235 static constexpr const char *kResume = "__resume"; 236 237 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 238 /// intrinsics. We need this function to be able to pass it to the async 239 /// runtime execute API. 240 static void addResumeFunction(ModuleOp module) { 241 if (module.lookupSymbol(kResume)) 242 return; 243 244 MLIRContext *ctx = module.getContext(); 245 auto loc = module.getLoc(); 246 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 247 248 auto voidTy = LLVM::LLVMVoidType::get(ctx); 249 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 250 251 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 252 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 253 resumeOp.setPrivate(); 254 255 auto *block = resumeOp.addEntryBlock(); 256 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 257 258 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 259 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // Convert Async dialect types to LLVM types. 264 //===----------------------------------------------------------------------===// 265 266 namespace { 267 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 268 /// their runtime type (opaque pointers) and does not convert any other types. 269 class AsyncRuntimeTypeConverter : public TypeConverter { 270 public: 271 AsyncRuntimeTypeConverter() { 272 addConversion([](Type type) { return type; }); 273 addConversion(convertAsyncTypes); 274 } 275 276 static Optional<Type> convertAsyncTypes(Type type) { 277 if (type.isa<TokenType, GroupType, ValueType>()) 278 return AsyncAPI::opaquePointerType(type.getContext()); 279 280 if (type.isa<CoroIdType, CoroStateType>()) 281 return AsyncAPI::tokenType(type.getContext()); 282 if (type.isa<CoroHandleType>()) 283 return AsyncAPI::opaquePointerType(type.getContext()); 284 285 return llvm::None; 286 } 287 }; 288 } // namespace 289 290 //===----------------------------------------------------------------------===// 291 // Convert async.coro.id to @llvm.coro.id intrinsic. 292 //===----------------------------------------------------------------------===// 293 294 namespace { 295 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 296 public: 297 using OpConversionPattern::OpConversionPattern; 298 299 LogicalResult 300 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 301 ConversionPatternRewriter &rewriter) const override { 302 auto token = AsyncAPI::tokenType(op->getContext()); 303 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 304 auto loc = op->getLoc(); 305 306 // Constants for initializing coroutine frame. 307 auto constZero = rewriter.create<LLVM::ConstantOp>( 308 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 309 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 310 311 // Get coroutine id: @llvm.coro.id. 312 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 313 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 314 315 return success(); 316 } 317 }; 318 } // namespace 319 320 //===----------------------------------------------------------------------===// 321 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 322 //===----------------------------------------------------------------------===// 323 324 namespace { 325 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 326 public: 327 using OpConversionPattern::OpConversionPattern; 328 329 LogicalResult 330 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 331 ConversionPatternRewriter &rewriter) const override { 332 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 333 auto loc = op->getLoc(); 334 335 // Get coroutine frame size: @llvm.coro.size.i64. 336 auto coroSize = 337 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 338 // The coroutine lowering doesn't properly account for alignment of the 339 // frame, so align everything to 64 bytes which ought to be enough for 340 // everyone. https://llvm.org/PR53148 341 auto coroAlign = rewriter.create<LLVM::ConstantOp>( 342 op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(64)); 343 344 // Allocate memory for the coroutine frame. 345 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( 346 op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); 347 auto coroAlloc = rewriter.create<LLVM::CallOp>( 348 loc, i8Ptr, SymbolRefAttr::get(allocFuncOp), 349 ValueRange{coroAlign, coroSize.getResult()}); 350 351 // Begin a coroutine: @llvm.coro.begin. 352 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); 353 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 354 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 355 356 return success(); 357 } 358 }; 359 } // namespace 360 361 //===----------------------------------------------------------------------===// 362 // Convert async.coro.free to @llvm.coro.free intrinsic. 363 //===----------------------------------------------------------------------===// 364 365 namespace { 366 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 367 public: 368 using OpConversionPattern::OpConversionPattern; 369 370 LogicalResult 371 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 372 ConversionPatternRewriter &rewriter) const override { 373 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 374 auto loc = op->getLoc(); 375 376 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 377 auto coroMem = 378 rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands()); 379 380 // Free the memory. 381 auto freeFuncOp = 382 LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 383 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 384 SymbolRefAttr::get(freeFuncOp), 385 ValueRange(coroMem.getResult())); 386 387 return success(); 388 } 389 }; 390 } // namespace 391 392 //===----------------------------------------------------------------------===// 393 // Convert async.coro.end to @llvm.coro.end intrinsic. 394 //===----------------------------------------------------------------------===// 395 396 namespace { 397 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 398 public: 399 using OpConversionPattern::OpConversionPattern; 400 401 LogicalResult 402 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 403 ConversionPatternRewriter &rewriter) const override { 404 // We are not in the block that is part of the unwind sequence. 405 auto constFalse = rewriter.create<LLVM::ConstantOp>( 406 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 407 408 // Mark the end of a coroutine: @llvm.coro.end. 409 auto coroHdl = adaptor.handle(); 410 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 411 ValueRange({coroHdl, constFalse})); 412 rewriter.eraseOp(op); 413 414 return success(); 415 } 416 }; 417 } // namespace 418 419 //===----------------------------------------------------------------------===// 420 // Convert async.coro.save to @llvm.coro.save intrinsic. 421 //===----------------------------------------------------------------------===// 422 423 namespace { 424 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 425 public: 426 using OpConversionPattern::OpConversionPattern; 427 428 LogicalResult 429 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 430 ConversionPatternRewriter &rewriter) const override { 431 // Save the coroutine state: @llvm.coro.save 432 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 433 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 434 435 return success(); 436 } 437 }; 438 } // namespace 439 440 //===----------------------------------------------------------------------===// 441 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 442 //===----------------------------------------------------------------------===// 443 444 namespace { 445 446 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 447 /// branch to the appropriate block based on the return code. 448 /// 449 /// Before: 450 /// 451 /// ^suspended: 452 /// "opBefore"(...) 453 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 454 /// ^resume: 455 /// "op"(...) 456 /// ^cleanup: ... 457 /// ^suspend: ... 458 /// 459 /// After: 460 /// 461 /// ^suspended: 462 /// "opBefore"(...) 463 /// %suspend = llmv.intr.coro.suspend ... 464 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 465 /// ^resume: 466 /// "op"(...) 467 /// ^cleanup: ... 468 /// ^suspend: ... 469 /// 470 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 471 public: 472 using OpConversionPattern::OpConversionPattern; 473 474 LogicalResult 475 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 476 ConversionPatternRewriter &rewriter) const override { 477 auto i8 = rewriter.getIntegerType(8); 478 auto i32 = rewriter.getI32Type(); 479 auto loc = op->getLoc(); 480 481 // This is not a final suspension point. 482 auto constFalse = rewriter.create<LLVM::ConstantOp>( 483 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 484 485 // Suspend a coroutine: @llvm.coro.suspend 486 auto coroState = adaptor.state(); 487 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 488 loc, i8, ValueRange({coroState, constFalse})); 489 490 // Cast return code to i32. 491 492 // After a suspension point decide if we should branch into resume, cleanup 493 // or suspend block of the coroutine (see @llvm.coro.suspend return code 494 // documentation). 495 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 496 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 497 op.cleanupDest()}; 498 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 499 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 500 /*defaultDestination=*/op.suspendDest(), 501 /*defaultOperands=*/ValueRange(), 502 /*caseValues=*/caseValues, 503 /*caseDestinations=*/caseDest, 504 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 505 /*branchWeights=*/ArrayRef<int32_t>()); 506 507 return success(); 508 } 509 }; 510 } // namespace 511 512 //===----------------------------------------------------------------------===// 513 // Convert async.runtime.create to the corresponding runtime API call. 514 // 515 // To allocate storage for the async values we use getelementptr trick: 516 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 517 //===----------------------------------------------------------------------===// 518 519 namespace { 520 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 521 public: 522 using OpConversionPattern::OpConversionPattern; 523 524 LogicalResult 525 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 526 ConversionPatternRewriter &rewriter) const override { 527 TypeConverter *converter = getTypeConverter(); 528 Type resultType = op->getResultTypes()[0]; 529 530 // Tokens creation maps to a simple function call. 531 if (resultType.isa<TokenType>()) { 532 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken, 533 converter->convertType(resultType)); 534 return success(); 535 } 536 537 // To create a value we need to compute the storage requirement. 538 if (auto value = resultType.dyn_cast<ValueType>()) { 539 // Returns the size requirements for the async value storage. 540 auto sizeOf = [&](ValueType valueType) -> Value { 541 auto loc = op->getLoc(); 542 auto i64 = rewriter.getI64Type(); 543 544 auto storedType = converter->convertType(valueType.getValueType()); 545 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 546 547 // %Size = getelementptr %T* null, int 1 548 // %SizeI = ptrtoint %T* %Size to i64 549 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 550 auto one = rewriter.create<LLVM::ConstantOp>( 551 loc, i64, rewriter.getI64IntegerAttr(1)); 552 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 553 one.getResult()); 554 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 555 }; 556 557 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 558 sizeOf(value)); 559 560 return success(); 561 } 562 563 return rewriter.notifyMatchFailure(op, "unsupported async type"); 564 } 565 }; 566 } // namespace 567 568 //===----------------------------------------------------------------------===// 569 // Convert async.runtime.create_group to the corresponding runtime API call. 570 //===----------------------------------------------------------------------===// 571 572 namespace { 573 class RuntimeCreateGroupOpLowering 574 : public OpConversionPattern<RuntimeCreateGroupOp> { 575 public: 576 using OpConversionPattern::OpConversionPattern; 577 578 LogicalResult 579 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 580 ConversionPatternRewriter &rewriter) const override { 581 TypeConverter *converter = getTypeConverter(); 582 Type resultType = op.getResult().getType(); 583 584 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, 585 converter->convertType(resultType), 586 adaptor.getOperands()); 587 return success(); 588 } 589 }; 590 } // namespace 591 592 //===----------------------------------------------------------------------===// 593 // Convert async.runtime.set_available to the corresponding runtime API call. 594 //===----------------------------------------------------------------------===// 595 596 namespace { 597 class RuntimeSetAvailableOpLowering 598 : public OpConversionPattern<RuntimeSetAvailableOp> { 599 public: 600 using OpConversionPattern::OpConversionPattern; 601 602 LogicalResult 603 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 604 ConversionPatternRewriter &rewriter) const override { 605 StringRef apiFuncName = 606 TypeSwitch<Type, StringRef>(op.operand().getType()) 607 .Case<TokenType>([](Type) { return kEmplaceToken; }) 608 .Case<ValueType>([](Type) { return kEmplaceValue; }); 609 610 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 611 adaptor.getOperands()); 612 613 return success(); 614 } 615 }; 616 } // namespace 617 618 //===----------------------------------------------------------------------===// 619 // Convert async.runtime.set_error to the corresponding runtime API call. 620 //===----------------------------------------------------------------------===// 621 622 namespace { 623 class RuntimeSetErrorOpLowering 624 : public OpConversionPattern<RuntimeSetErrorOp> { 625 public: 626 using OpConversionPattern::OpConversionPattern; 627 628 LogicalResult 629 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 630 ConversionPatternRewriter &rewriter) const override { 631 StringRef apiFuncName = 632 TypeSwitch<Type, StringRef>(op.operand().getType()) 633 .Case<TokenType>([](Type) { return kSetTokenError; }) 634 .Case<ValueType>([](Type) { return kSetValueError; }); 635 636 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 637 adaptor.getOperands()); 638 639 return success(); 640 } 641 }; 642 } // namespace 643 644 //===----------------------------------------------------------------------===// 645 // Convert async.runtime.is_error to the corresponding runtime API call. 646 //===----------------------------------------------------------------------===// 647 648 namespace { 649 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 650 public: 651 using OpConversionPattern::OpConversionPattern; 652 653 LogicalResult 654 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 655 ConversionPatternRewriter &rewriter) const override { 656 StringRef apiFuncName = 657 TypeSwitch<Type, StringRef>(op.operand().getType()) 658 .Case<TokenType>([](Type) { return kIsTokenError; }) 659 .Case<GroupType>([](Type) { return kIsGroupError; }) 660 .Case<ValueType>([](Type) { return kIsValueError; }); 661 662 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 663 adaptor.getOperands()); 664 return success(); 665 } 666 }; 667 } // namespace 668 669 //===----------------------------------------------------------------------===// 670 // Convert async.runtime.await to the corresponding runtime API call. 671 //===----------------------------------------------------------------------===// 672 673 namespace { 674 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 675 public: 676 using OpConversionPattern::OpConversionPattern; 677 678 LogicalResult 679 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 680 ConversionPatternRewriter &rewriter) const override { 681 StringRef apiFuncName = 682 TypeSwitch<Type, StringRef>(op.operand().getType()) 683 .Case<TokenType>([](Type) { return kAwaitToken; }) 684 .Case<ValueType>([](Type) { return kAwaitValue; }) 685 .Case<GroupType>([](Type) { return kAwaitGroup; }); 686 687 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 688 adaptor.getOperands()); 689 rewriter.eraseOp(op); 690 691 return success(); 692 } 693 }; 694 } // namespace 695 696 //===----------------------------------------------------------------------===// 697 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 698 //===----------------------------------------------------------------------===// 699 700 namespace { 701 class RuntimeAwaitAndResumeOpLowering 702 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 703 public: 704 using OpConversionPattern::OpConversionPattern; 705 706 LogicalResult 707 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 708 ConversionPatternRewriter &rewriter) const override { 709 StringRef apiFuncName = 710 TypeSwitch<Type, StringRef>(op.operand().getType()) 711 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 712 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 713 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 714 715 Value operand = adaptor.operand(); 716 Value handle = adaptor.handle(); 717 718 // A pointer to coroutine resume intrinsic wrapper. 719 addResumeFunction(op->getParentOfType<ModuleOp>()); 720 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 721 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 722 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 723 724 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 725 ValueRange({operand, handle, resumePtr.getRes()})); 726 rewriter.eraseOp(op); 727 728 return success(); 729 } 730 }; 731 } // namespace 732 733 //===----------------------------------------------------------------------===// 734 // Convert async.runtime.resume to the corresponding runtime API call. 735 //===----------------------------------------------------------------------===// 736 737 namespace { 738 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 739 public: 740 using OpConversionPattern::OpConversionPattern; 741 742 LogicalResult 743 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 744 ConversionPatternRewriter &rewriter) const override { 745 // A pointer to coroutine resume intrinsic wrapper. 746 addResumeFunction(op->getParentOfType<ModuleOp>()); 747 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 748 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 749 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 750 751 // Call async runtime API to execute a coroutine in the managed thread. 752 auto coroHdl = adaptor.handle(); 753 rewriter.replaceOpWithNewOp<CallOp>( 754 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); 755 756 return success(); 757 } 758 }; 759 } // namespace 760 761 //===----------------------------------------------------------------------===// 762 // Convert async.runtime.store to the corresponding runtime API call. 763 //===----------------------------------------------------------------------===// 764 765 namespace { 766 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 767 public: 768 using OpConversionPattern::OpConversionPattern; 769 770 LogicalResult 771 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 772 ConversionPatternRewriter &rewriter) const override { 773 Location loc = op->getLoc(); 774 775 // Get a pointer to the async value storage from the runtime. 776 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 777 auto storage = adaptor.storage(); 778 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 779 TypeRange(i8Ptr), storage); 780 781 // Cast from i8* to the LLVM pointer type. 782 auto valueType = op.value().getType(); 783 auto llvmValueType = getTypeConverter()->convertType(valueType); 784 if (!llvmValueType) 785 return rewriter.notifyMatchFailure( 786 op, "failed to convert stored value type to LLVM type"); 787 788 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 789 loc, LLVM::LLVMPointerType::get(llvmValueType), 790 storagePtr.getResult(0)); 791 792 // Store the yielded value into the async value storage. 793 auto value = adaptor.value(); 794 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 795 796 // Erase the original runtime store operation. 797 rewriter.eraseOp(op); 798 799 return success(); 800 } 801 }; 802 } // namespace 803 804 //===----------------------------------------------------------------------===// 805 // Convert async.runtime.load to the corresponding runtime API call. 806 //===----------------------------------------------------------------------===// 807 808 namespace { 809 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 810 public: 811 using OpConversionPattern::OpConversionPattern; 812 813 LogicalResult 814 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 815 ConversionPatternRewriter &rewriter) const override { 816 Location loc = op->getLoc(); 817 818 // Get a pointer to the async value storage from the runtime. 819 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 820 auto storage = adaptor.storage(); 821 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 822 TypeRange(i8Ptr), storage); 823 824 // Cast from i8* to the LLVM pointer type. 825 auto valueType = op.result().getType(); 826 auto llvmValueType = getTypeConverter()->convertType(valueType); 827 if (!llvmValueType) 828 return rewriter.notifyMatchFailure( 829 op, "failed to convert loaded value type to LLVM type"); 830 831 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 832 loc, LLVM::LLVMPointerType::get(llvmValueType), 833 storagePtr.getResult(0)); 834 835 // Load from the casted pointer. 836 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 837 838 return success(); 839 } 840 }; 841 } // namespace 842 843 //===----------------------------------------------------------------------===// 844 // Convert async.runtime.add_to_group to the corresponding runtime API call. 845 //===----------------------------------------------------------------------===// 846 847 namespace { 848 class RuntimeAddToGroupOpLowering 849 : public OpConversionPattern<RuntimeAddToGroupOp> { 850 public: 851 using OpConversionPattern::OpConversionPattern; 852 853 LogicalResult 854 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 855 ConversionPatternRewriter &rewriter) const override { 856 // Currently we can only add tokens to the group. 857 if (!op.operand().getType().isa<TokenType>()) 858 return rewriter.notifyMatchFailure(op, "only token type is supported"); 859 860 // Replace with a runtime API function call. 861 rewriter.replaceOpWithNewOp<CallOp>( 862 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 863 864 return success(); 865 } 866 }; 867 } // namespace 868 869 //===----------------------------------------------------------------------===// 870 // Async reference counting ops lowering (`async.runtime.add_ref` and 871 // `async.runtime.drop_ref` to the corresponding API calls). 872 //===----------------------------------------------------------------------===// 873 874 namespace { 875 template <typename RefCountingOp> 876 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 877 public: 878 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 879 StringRef apiFunctionName) 880 : OpConversionPattern<RefCountingOp>(converter, ctx), 881 apiFunctionName(apiFunctionName) {} 882 883 LogicalResult 884 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 885 ConversionPatternRewriter &rewriter) const override { 886 auto count = rewriter.create<arith::ConstantOp>( 887 op->getLoc(), rewriter.getI64Type(), 888 rewriter.getI64IntegerAttr(op.count())); 889 890 auto operand = adaptor.operand(); 891 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 892 ValueRange({operand, count})); 893 894 return success(); 895 } 896 897 private: 898 StringRef apiFunctionName; 899 }; 900 901 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 902 public: 903 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 904 : RefCountingOpLowering(converter, ctx, kAddRef) {} 905 }; 906 907 class RuntimeDropRefOpLowering 908 : public RefCountingOpLowering<RuntimeDropRefOp> { 909 public: 910 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 911 : RefCountingOpLowering(converter, ctx, kDropRef) {} 912 }; 913 } // namespace 914 915 //===----------------------------------------------------------------------===// 916 // Convert return operations that return async values from async regions. 917 //===----------------------------------------------------------------------===// 918 919 namespace { 920 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 921 public: 922 using OpConversionPattern::OpConversionPattern; 923 924 LogicalResult 925 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 926 ConversionPatternRewriter &rewriter) const override { 927 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 928 return success(); 929 } 930 }; 931 } // namespace 932 933 //===----------------------------------------------------------------------===// 934 935 namespace { 936 struct ConvertAsyncToLLVMPass 937 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 938 void runOnOperation() override; 939 }; 940 } // namespace 941 942 void ConvertAsyncToLLVMPass::runOnOperation() { 943 ModuleOp module = getOperation(); 944 MLIRContext *ctx = module->getContext(); 945 946 // Add declarations for most functions required by the coroutines lowering. 947 // We delay adding the resume function until it's needed because it currently 948 // fails to compile unless '-O0' is specified. 949 addAsyncRuntimeApiDeclarations(module); 950 951 // Lower async.runtime and async.coro operations to Async Runtime API and 952 // LLVM coroutine intrinsics. 953 954 // Convert async dialect types and operations to LLVM dialect. 955 AsyncRuntimeTypeConverter converter; 956 RewritePatternSet patterns(ctx); 957 958 // We use conversion to LLVM type to lower async.runtime load and store 959 // operations. 960 LLVMTypeConverter llvmConverter(ctx); 961 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 962 963 // Convert async types in function signatures and function calls. 964 populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, converter); 965 populateCallOpTypeConversionPattern(patterns, converter); 966 967 // Convert return operations inside async.execute regions. 968 patterns.add<ReturnOpOpConversion>(converter, ctx); 969 970 // Lower async.runtime operations to the async runtime API calls. 971 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 972 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 973 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 974 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 975 RuntimeDropRefOpLowering>(converter, ctx); 976 977 // Lower async.runtime operations that rely on LLVM type converter to convert 978 // from async value payload type to the LLVM type. 979 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 980 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 981 ctx); 982 983 // Lower async coroutine operations to LLVM coroutine intrinsics. 984 patterns 985 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 986 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 987 converter, ctx); 988 989 ConversionTarget target(*ctx); 990 target 991 .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>(); 992 target.addLegalDialect<LLVM::LLVMDialect>(); 993 994 // All operations from Async dialect must be lowered to the runtime API and 995 // LLVM intrinsics calls. 996 target.addIllegalDialect<AsyncDialect>(); 997 998 // Add dynamic legality constraints to apply conversions defined above. 999 target.addDynamicallyLegalOp<FuncOp>( 1000 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1001 target.addDynamicallyLegalOp<ReturnOp>( 1002 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1003 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1004 return converter.isSignatureLegal(op.getCalleeType()); 1005 }); 1006 1007 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1008 signalPassFailure(); 1009 } 1010 1011 //===----------------------------------------------------------------------===// 1012 // Patterns for structural type conversions for the Async dialect operations. 1013 //===----------------------------------------------------------------------===// 1014 1015 namespace { 1016 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1017 public: 1018 using OpConversionPattern::OpConversionPattern; 1019 LogicalResult 1020 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1021 ConversionPatternRewriter &rewriter) const override { 1022 ExecuteOp newOp = 1023 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1024 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1025 newOp.getRegion().end()); 1026 1027 // Set operands and update block argument and result types. 1028 newOp->setOperands(adaptor.getOperands()); 1029 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1030 return failure(); 1031 for (auto result : newOp.getResults()) 1032 result.setType(typeConverter->convertType(result.getType())); 1033 1034 rewriter.replaceOp(op, newOp.getResults()); 1035 return success(); 1036 } 1037 }; 1038 1039 // Dummy pattern to trigger the appropriate type conversion / materialization. 1040 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1041 public: 1042 using OpConversionPattern::OpConversionPattern; 1043 LogicalResult 1044 matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1045 ConversionPatternRewriter &rewriter) const override { 1046 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1047 return success(); 1048 } 1049 }; 1050 1051 // Dummy pattern to trigger the appropriate type conversion / materialization. 1052 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1053 public: 1054 using OpConversionPattern::OpConversionPattern; 1055 LogicalResult 1056 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1057 ConversionPatternRewriter &rewriter) const override { 1058 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1059 return success(); 1060 } 1061 }; 1062 } // namespace 1063 1064 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1065 return std::make_unique<ConvertAsyncToLLVMPass>(); 1066 } 1067 1068 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1069 TypeConverter &typeConverter, RewritePatternSet &patterns, 1070 ConversionTarget &target) { 1071 typeConverter.addConversion([&](TokenType type) { return type; }); 1072 typeConverter.addConversion([&](ValueType type) { 1073 Type converted = typeConverter.convertType(type.getValueType()); 1074 return converted ? ValueType::get(converted) : converted; 1075 }); 1076 1077 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1078 typeConverter, patterns.getContext()); 1079 1080 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1081 [&](Operation *op) { return typeConverter.isLegal(op); }); 1082 } 1083