//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "convert-async-to-llvm"

// Async dialect ops and types (TokenType, ValueType, GroupType, CoroIdOp, ...)
// are referenced unqualified throughout this file.
using namespace mlir;
using namespace mlir::async;

//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
33 //===----------------------------------------------------------------------===// 34 35 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 36 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 37 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 38 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 39 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 40 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 41 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 42 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 43 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 44 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 45 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 46 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 47 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 48 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 49 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 50 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 51 static constexpr const char *kGetValueStorage = 52 "mlirAsyncRuntimeGetValueStorage"; 53 static constexpr const char *kAddTokenToGroup = 54 "mlirAsyncRuntimeAddTokenToGroup"; 55 static constexpr const char *kAwaitTokenAndExecute = 56 "mlirAsyncRuntimeAwaitTokenAndExecute"; 57 static constexpr const char *kAwaitValueAndExecute = 58 "mlirAsyncRuntimeAwaitValueAndExecute"; 59 static constexpr const char *kAwaitAllAndExecute = 60 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 61 62 namespace { 63 /// Async Runtime API function types. 
64 /// 65 /// Because we can't create API function signature for type parametrized 66 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 67 /// lowering all async data types become opaque pointers at runtime. 68 struct AsyncAPI { 69 // All async types are lowered to opaque i8* LLVM pointers at runtime. 70 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 71 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 72 } 73 74 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 75 return LLVM::LLVMTokenType::get(ctx); 76 } 77 78 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 79 auto ref = opaquePointerType(ctx); 80 auto count = IntegerType::get(ctx, 64); 81 return FunctionType::get(ctx, {ref, count}, {}); 82 } 83 84 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 85 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 86 } 87 88 static FunctionType createValueFunctionType(MLIRContext *ctx) { 89 auto i64 = IntegerType::get(ctx, 64); 90 auto value = opaquePointerType(ctx); 91 return FunctionType::get(ctx, {i64}, {value}); 92 } 93 94 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 95 auto i64 = IntegerType::get(ctx, 64); 96 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 97 } 98 99 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 100 auto value = opaquePointerType(ctx); 101 auto storage = opaquePointerType(ctx); 102 return FunctionType::get(ctx, {value}, {storage}); 103 } 104 105 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 106 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 107 } 108 109 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 110 auto value = opaquePointerType(ctx); 111 return FunctionType::get(ctx, {value}, {}); 112 } 113 114 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 115 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 116 } 117 118 static 
FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 119 auto value = opaquePointerType(ctx); 120 return FunctionType::get(ctx, {value}, {}); 121 } 122 123 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 124 auto i1 = IntegerType::get(ctx, 1); 125 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 126 } 127 128 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 129 auto value = opaquePointerType(ctx); 130 auto i1 = IntegerType::get(ctx, 1); 131 return FunctionType::get(ctx, {value}, {i1}); 132 } 133 134 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 135 auto i1 = IntegerType::get(ctx, 1); 136 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 137 } 138 139 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 140 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 141 } 142 143 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 144 auto value = opaquePointerType(ctx); 145 return FunctionType::get(ctx, {value}, {}); 146 } 147 148 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 149 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 150 } 151 152 static FunctionType executeFunctionType(MLIRContext *ctx) { 153 auto hdl = opaquePointerType(ctx); 154 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 155 return FunctionType::get(ctx, {hdl, resume}, {}); 156 } 157 158 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 159 auto i64 = IntegerType::get(ctx, 64); 160 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 161 {i64}); 162 } 163 164 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 165 auto hdl = opaquePointerType(ctx); 166 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 167 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 168 } 169 170 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 171 auto value = 
opaquePointerType(ctx); 172 auto hdl = opaquePointerType(ctx); 173 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 174 return FunctionType::get(ctx, {value, hdl, resume}, {}); 175 } 176 177 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 178 auto hdl = opaquePointerType(ctx); 179 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 180 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 181 } 182 183 // Auxiliary coroutine resume intrinsic wrapper. 184 static Type resumeFunctionType(MLIRContext *ctx) { 185 auto voidTy = LLVM::LLVMVoidType::get(ctx); 186 auto i8Ptr = opaquePointerType(ctx); 187 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 188 } 189 }; 190 } // namespace 191 192 /// Adds Async Runtime C API declarations to the module. 193 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 194 auto builder = 195 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 196 197 auto addFuncDecl = [&](StringRef name, FunctionType type) { 198 if (module.lookupSymbol(name)) 199 return; 200 builder.create<FuncOp>(name, type).setPrivate(); 201 }; 202 203 MLIRContext *ctx = module.getContext(); 204 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 205 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 206 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 207 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 208 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 209 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 210 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 211 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 212 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 213 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 214 addFuncDecl(kIsValueError, 
AsyncAPI::isValueErrorFunctionType(ctx)); 215 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 216 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 217 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 218 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 219 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 220 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 221 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 222 addFuncDecl(kAwaitTokenAndExecute, 223 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 224 addFuncDecl(kAwaitValueAndExecute, 225 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 226 addFuncDecl(kAwaitAllAndExecute, 227 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // Add malloc/free declarations to the module. 232 //===----------------------------------------------------------------------===// 233 234 static constexpr const char *kMalloc = "malloc"; 235 static constexpr const char *kFree = "free"; 236 237 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 238 StringRef name, Type ret, ArrayRef<Type> params) { 239 if (module.lookupSymbol(name)) 240 return; 241 Type type = LLVM::LLVMFunctionType::get(ret, params); 242 builder.create<LLVM::LLVMFuncOp>(name, type); 243 } 244 245 /// Adds malloc/free declarations to the module. 
246 static void addCRuntimeDeclarations(ModuleOp module) { 247 using namespace mlir::LLVM; 248 249 MLIRContext *ctx = module.getContext(); 250 auto builder = 251 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 252 253 auto voidTy = LLVMVoidType::get(ctx); 254 auto i64 = IntegerType::get(ctx, 64); 255 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 256 257 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 258 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 259 } 260 261 //===----------------------------------------------------------------------===// 262 // Coroutine resume function wrapper. 263 //===----------------------------------------------------------------------===// 264 265 static constexpr const char *kResume = "__resume"; 266 267 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 268 /// intrinsics. We need this function to be able to pass it to the async 269 /// runtime execute API. 270 static void addResumeFunction(ModuleOp module) { 271 if (module.lookupSymbol(kResume)) 272 return; 273 274 MLIRContext *ctx = module.getContext(); 275 auto loc = module.getLoc(); 276 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 277 278 auto voidTy = LLVM::LLVMVoidType::get(ctx); 279 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 280 281 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 282 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 283 resumeOp.setPrivate(); 284 285 auto *block = resumeOp.addEntryBlock(); 286 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 287 288 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 289 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // Convert Async dialect types to LLVM types. 
294 //===----------------------------------------------------------------------===// 295 296 namespace { 297 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 298 /// their runtime type (opaque pointers) and does not convert any other types. 299 class AsyncRuntimeTypeConverter : public TypeConverter { 300 public: 301 AsyncRuntimeTypeConverter() { 302 addConversion([](Type type) { return type; }); 303 addConversion(convertAsyncTypes); 304 } 305 306 static Optional<Type> convertAsyncTypes(Type type) { 307 if (type.isa<TokenType, GroupType, ValueType>()) 308 return AsyncAPI::opaquePointerType(type.getContext()); 309 310 if (type.isa<CoroIdType, CoroStateType>()) 311 return AsyncAPI::tokenType(type.getContext()); 312 if (type.isa<CoroHandleType>()) 313 return AsyncAPI::opaquePointerType(type.getContext()); 314 315 return llvm::None; 316 } 317 }; 318 } // namespace 319 320 //===----------------------------------------------------------------------===// 321 // Convert async.coro.id to @llvm.coro.id intrinsic. 322 //===----------------------------------------------------------------------===// 323 324 namespace { 325 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 326 public: 327 using OpConversionPattern::OpConversionPattern; 328 329 LogicalResult 330 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 331 ConversionPatternRewriter &rewriter) const override { 332 auto token = AsyncAPI::tokenType(op->getContext()); 333 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 334 auto loc = op->getLoc(); 335 336 // Constants for initializing coroutine frame. 337 auto constZero = rewriter.create<LLVM::ConstantOp>( 338 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 339 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 340 341 // Get coroutine id: @llvm.coro.id. 
342 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 343 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 344 345 return success(); 346 } 347 }; 348 } // namespace 349 350 //===----------------------------------------------------------------------===// 351 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 352 //===----------------------------------------------------------------------===// 353 354 namespace { 355 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 356 public: 357 using OpConversionPattern::OpConversionPattern; 358 359 LogicalResult 360 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 361 ConversionPatternRewriter &rewriter) const override { 362 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 363 auto loc = op->getLoc(); 364 365 // Get coroutine frame size: @llvm.coro.size.i64. 366 auto coroSize = 367 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 368 369 // Allocate memory for the coroutine frame. 370 auto coroAlloc = rewriter.create<LLVM::CallOp>( 371 loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc), 372 ValueRange(coroSize.getResult())); 373 374 // Begin a coroutine: @llvm.coro.begin. 375 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); 376 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 377 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 378 379 return success(); 380 } 381 }; 382 } // namespace 383 384 //===----------------------------------------------------------------------===// 385 // Convert async.coro.free to @llvm.coro.free intrinsic. 
386 //===----------------------------------------------------------------------===// 387 388 namespace { 389 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 390 public: 391 using OpConversionPattern::OpConversionPattern; 392 393 LogicalResult 394 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 395 ConversionPatternRewriter &rewriter) const override { 396 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 397 auto loc = op->getLoc(); 398 399 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 400 auto coroMem = 401 rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands()); 402 403 // Free the memory. 404 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 405 op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree), 406 ValueRange(coroMem.getResult())); 407 408 return success(); 409 } 410 }; 411 } // namespace 412 413 //===----------------------------------------------------------------------===// 414 // Convert async.coro.end to @llvm.coro.end intrinsic. 415 //===----------------------------------------------------------------------===// 416 417 namespace { 418 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 419 public: 420 using OpConversionPattern::OpConversionPattern; 421 422 LogicalResult 423 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 424 ConversionPatternRewriter &rewriter) const override { 425 // We are not in the block that is part of the unwind sequence. 426 auto constFalse = rewriter.create<LLVM::ConstantOp>( 427 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 428 429 // Mark the end of a coroutine: @llvm.coro.end. 
430 auto coroHdl = adaptor.handle(); 431 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 432 ValueRange({coroHdl, constFalse})); 433 rewriter.eraseOp(op); 434 435 return success(); 436 } 437 }; 438 } // namespace 439 440 //===----------------------------------------------------------------------===// 441 // Convert async.coro.save to @llvm.coro.save intrinsic. 442 //===----------------------------------------------------------------------===// 443 444 namespace { 445 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 446 public: 447 using OpConversionPattern::OpConversionPattern; 448 449 LogicalResult 450 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 451 ConversionPatternRewriter &rewriter) const override { 452 // Save the coroutine state: @llvm.coro.save 453 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 454 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 455 456 return success(); 457 } 458 }; 459 } // namespace 460 461 //===----------------------------------------------------------------------===// 462 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 463 //===----------------------------------------------------------------------===// 464 465 namespace { 466 467 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 468 /// branch to the appropriate block based on the return code. 469 /// 470 /// Before: 471 /// 472 /// ^suspended: 473 /// "opBefore"(...) 474 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 475 /// ^resume: 476 /// "op"(...) 477 /// ^cleanup: ... 478 /// ^suspend: ... 479 /// 480 /// After: 481 /// 482 /// ^suspended: 483 /// "opBefore"(...) 484 /// %suspend = llmv.intr.coro.suspend ... 485 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 486 /// ^resume: 487 /// "op"(...) 488 /// ^cleanup: ... 489 /// ^suspend: ... 
490 /// 491 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 492 public: 493 using OpConversionPattern::OpConversionPattern; 494 495 LogicalResult 496 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 497 ConversionPatternRewriter &rewriter) const override { 498 auto i8 = rewriter.getIntegerType(8); 499 auto i32 = rewriter.getI32Type(); 500 auto loc = op->getLoc(); 501 502 // This is not a final suspension point. 503 auto constFalse = rewriter.create<LLVM::ConstantOp>( 504 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 505 506 // Suspend a coroutine: @llvm.coro.suspend 507 auto coroState = adaptor.state(); 508 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 509 loc, i8, ValueRange({coroState, constFalse})); 510 511 // Cast return code to i32. 512 513 // After a suspension point decide if we should branch into resume, cleanup 514 // or suspend block of the coroutine (see @llvm.coro.suspend return code 515 // documentation). 516 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 517 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 518 op.cleanupDest()}; 519 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 520 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 521 /*defaultDestination=*/op.suspendDest(), 522 /*defaultOperands=*/ValueRange(), 523 /*caseValues=*/caseValues, 524 /*caseDestinations=*/caseDest, 525 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 526 /*branchWeights=*/ArrayRef<int32_t>()); 527 528 return success(); 529 } 530 }; 531 } // namespace 532 533 //===----------------------------------------------------------------------===// 534 // Convert async.runtime.create to the corresponding runtime API call. 
535 // 536 // To allocate storage for the async values we use getelementptr trick: 537 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 538 //===----------------------------------------------------------------------===// 539 540 namespace { 541 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 542 public: 543 using OpConversionPattern::OpConversionPattern; 544 545 LogicalResult 546 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 547 ConversionPatternRewriter &rewriter) const override { 548 TypeConverter *converter = getTypeConverter(); 549 Type resultType = op->getResultTypes()[0]; 550 551 // Tokens creation maps to a simple function call. 552 if (resultType.isa<TokenType>()) { 553 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken, 554 converter->convertType(resultType)); 555 return success(); 556 } 557 558 // To create a value we need to compute the storage requirement. 559 if (auto value = resultType.dyn_cast<ValueType>()) { 560 // Returns the size requirements for the async value storage. 
561 auto sizeOf = [&](ValueType valueType) -> Value { 562 auto loc = op->getLoc(); 563 auto i64 = rewriter.getI64Type(); 564 565 auto storedType = converter->convertType(valueType.getValueType()); 566 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 567 568 // %Size = getelementptr %T* null, int 1 569 // %SizeI = ptrtoint %T* %Size to i64 570 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 571 auto one = rewriter.create<LLVM::ConstantOp>( 572 loc, i64, rewriter.getI64IntegerAttr(1)); 573 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 574 one.getResult()); 575 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 576 }; 577 578 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 579 sizeOf(value)); 580 581 return success(); 582 } 583 584 return rewriter.notifyMatchFailure(op, "unsupported async type"); 585 } 586 }; 587 } // namespace 588 589 //===----------------------------------------------------------------------===// 590 // Convert async.runtime.create_group to the corresponding runtime API call. 591 //===----------------------------------------------------------------------===// 592 593 namespace { 594 class RuntimeCreateGroupOpLowering 595 : public OpConversionPattern<RuntimeCreateGroupOp> { 596 public: 597 using OpConversionPattern::OpConversionPattern; 598 599 LogicalResult 600 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 601 ConversionPatternRewriter &rewriter) const override { 602 TypeConverter *converter = getTypeConverter(); 603 Type resultType = op.getResult().getType(); 604 605 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, 606 converter->convertType(resultType), 607 adaptor.getOperands()); 608 return success(); 609 } 610 }; 611 } // namespace 612 613 //===----------------------------------------------------------------------===// 614 // Convert async.runtime.set_available to the corresponding runtime API call. 
615 //===----------------------------------------------------------------------===// 616 617 namespace { 618 class RuntimeSetAvailableOpLowering 619 : public OpConversionPattern<RuntimeSetAvailableOp> { 620 public: 621 using OpConversionPattern::OpConversionPattern; 622 623 LogicalResult 624 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 625 ConversionPatternRewriter &rewriter) const override { 626 StringRef apiFuncName = 627 TypeSwitch<Type, StringRef>(op.operand().getType()) 628 .Case<TokenType>([](Type) { return kEmplaceToken; }) 629 .Case<ValueType>([](Type) { return kEmplaceValue; }); 630 631 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 632 adaptor.getOperands()); 633 634 return success(); 635 } 636 }; 637 } // namespace 638 639 //===----------------------------------------------------------------------===// 640 // Convert async.runtime.set_error to the corresponding runtime API call. 641 //===----------------------------------------------------------------------===// 642 643 namespace { 644 class RuntimeSetErrorOpLowering 645 : public OpConversionPattern<RuntimeSetErrorOp> { 646 public: 647 using OpConversionPattern::OpConversionPattern; 648 649 LogicalResult 650 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 651 ConversionPatternRewriter &rewriter) const override { 652 StringRef apiFuncName = 653 TypeSwitch<Type, StringRef>(op.operand().getType()) 654 .Case<TokenType>([](Type) { return kSetTokenError; }) 655 .Case<ValueType>([](Type) { return kSetValueError; }); 656 657 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 658 adaptor.getOperands()); 659 660 return success(); 661 } 662 }; 663 } // namespace 664 665 //===----------------------------------------------------------------------===// 666 // Convert async.runtime.is_error to the corresponding runtime API call. 
667 //===----------------------------------------------------------------------===// 668 669 namespace { 670 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 671 public: 672 using OpConversionPattern::OpConversionPattern; 673 674 LogicalResult 675 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 676 ConversionPatternRewriter &rewriter) const override { 677 StringRef apiFuncName = 678 TypeSwitch<Type, StringRef>(op.operand().getType()) 679 .Case<TokenType>([](Type) { return kIsTokenError; }) 680 .Case<GroupType>([](Type) { return kIsGroupError; }) 681 .Case<ValueType>([](Type) { return kIsValueError; }); 682 683 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 684 adaptor.getOperands()); 685 return success(); 686 } 687 }; 688 } // namespace 689 690 //===----------------------------------------------------------------------===// 691 // Convert async.runtime.await to the corresponding runtime API call. 692 //===----------------------------------------------------------------------===// 693 694 namespace { 695 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 696 public: 697 using OpConversionPattern::OpConversionPattern; 698 699 LogicalResult 700 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 701 ConversionPatternRewriter &rewriter) const override { 702 StringRef apiFuncName = 703 TypeSwitch<Type, StringRef>(op.operand().getType()) 704 .Case<TokenType>([](Type) { return kAwaitToken; }) 705 .Case<ValueType>([](Type) { return kAwaitValue; }) 706 .Case<GroupType>([](Type) { return kAwaitGroup; }); 707 708 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 709 adaptor.getOperands()); 710 rewriter.eraseOp(op); 711 712 return success(); 713 } 714 }; 715 } // namespace 716 717 //===----------------------------------------------------------------------===// 718 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 
719 //===----------------------------------------------------------------------===// 720 721 namespace { 722 class RuntimeAwaitAndResumeOpLowering 723 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 724 public: 725 using OpConversionPattern::OpConversionPattern; 726 727 LogicalResult 728 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 729 ConversionPatternRewriter &rewriter) const override { 730 StringRef apiFuncName = 731 TypeSwitch<Type, StringRef>(op.operand().getType()) 732 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 733 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 734 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 735 736 Value operand = adaptor.operand(); 737 Value handle = adaptor.handle(); 738 739 // A pointer to coroutine resume intrinsic wrapper. 740 addResumeFunction(op->getParentOfType<ModuleOp>()); 741 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 742 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 743 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 744 745 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 746 ValueRange({operand, handle, resumePtr.getRes()})); 747 rewriter.eraseOp(op); 748 749 return success(); 750 } 751 }; 752 } // namespace 753 754 //===----------------------------------------------------------------------===// 755 // Convert async.runtime.resume to the corresponding runtime API call. 756 //===----------------------------------------------------------------------===// 757 758 namespace { 759 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 760 public: 761 using OpConversionPattern::OpConversionPattern; 762 763 LogicalResult 764 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 765 ConversionPatternRewriter &rewriter) const override { 766 // A pointer to coroutine resume intrinsic wrapper. 
767 addResumeFunction(op->getParentOfType<ModuleOp>()); 768 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 769 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 770 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 771 772 // Call async runtime API to execute a coroutine in the managed thread. 773 auto coroHdl = adaptor.handle(); 774 rewriter.replaceOpWithNewOp<CallOp>( 775 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); 776 777 return success(); 778 } 779 }; 780 } // namespace 781 782 //===----------------------------------------------------------------------===// 783 // Convert async.runtime.store to the corresponding runtime API call. 784 //===----------------------------------------------------------------------===// 785 786 namespace { 787 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 788 public: 789 using OpConversionPattern::OpConversionPattern; 790 791 LogicalResult 792 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 793 ConversionPatternRewriter &rewriter) const override { 794 Location loc = op->getLoc(); 795 796 // Get a pointer to the async value storage from the runtime. 797 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 798 auto storage = adaptor.storage(); 799 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 800 TypeRange(i8Ptr), storage); 801 802 // Cast from i8* to the LLVM pointer type. 803 auto valueType = op.value().getType(); 804 auto llvmValueType = getTypeConverter()->convertType(valueType); 805 if (!llvmValueType) 806 return rewriter.notifyMatchFailure( 807 op, "failed to convert stored value type to LLVM type"); 808 809 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 810 loc, LLVM::LLVMPointerType::get(llvmValueType), 811 storagePtr.getResult(0)); 812 813 // Store the yielded value into the async value storage. 
814 auto value = adaptor.value(); 815 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 816 817 // Erase the original runtime store operation. 818 rewriter.eraseOp(op); 819 820 return success(); 821 } 822 }; 823 } // namespace 824 825 //===----------------------------------------------------------------------===// 826 // Convert async.runtime.load to the corresponding runtime API call. 827 //===----------------------------------------------------------------------===// 828 829 namespace { 830 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 831 public: 832 using OpConversionPattern::OpConversionPattern; 833 834 LogicalResult 835 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 836 ConversionPatternRewriter &rewriter) const override { 837 Location loc = op->getLoc(); 838 839 // Get a pointer to the async value storage from the runtime. 840 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 841 auto storage = adaptor.storage(); 842 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 843 TypeRange(i8Ptr), storage); 844 845 // Cast from i8* to the LLVM pointer type. 846 auto valueType = op.result().getType(); 847 auto llvmValueType = getTypeConverter()->convertType(valueType); 848 if (!llvmValueType) 849 return rewriter.notifyMatchFailure( 850 op, "failed to convert loaded value type to LLVM type"); 851 852 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 853 loc, LLVM::LLVMPointerType::get(llvmValueType), 854 storagePtr.getResult(0)); 855 856 // Load from the casted pointer. 857 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 858 859 return success(); 860 } 861 }; 862 } // namespace 863 864 //===----------------------------------------------------------------------===// 865 // Convert async.runtime.add_to_group to the corresponding runtime API call. 
866 //===----------------------------------------------------------------------===// 867 868 namespace { 869 class RuntimeAddToGroupOpLowering 870 : public OpConversionPattern<RuntimeAddToGroupOp> { 871 public: 872 using OpConversionPattern::OpConversionPattern; 873 874 LogicalResult 875 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 876 ConversionPatternRewriter &rewriter) const override { 877 // Currently we can only add tokens to the group. 878 if (!op.operand().getType().isa<TokenType>()) 879 return rewriter.notifyMatchFailure(op, "only token type is supported"); 880 881 // Replace with a runtime API function call. 882 rewriter.replaceOpWithNewOp<CallOp>( 883 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 884 885 return success(); 886 } 887 }; 888 } // namespace 889 890 //===----------------------------------------------------------------------===// 891 // Async reference counting ops lowering (`async.runtime.add_ref` and 892 // `async.runtime.drop_ref` to the corresponding API calls). 
893 //===----------------------------------------------------------------------===// 894 895 namespace { 896 template <typename RefCountingOp> 897 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 898 public: 899 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 900 StringRef apiFunctionName) 901 : OpConversionPattern<RefCountingOp>(converter, ctx), 902 apiFunctionName(apiFunctionName) {} 903 904 LogicalResult 905 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 906 ConversionPatternRewriter &rewriter) const override { 907 auto count = rewriter.create<arith::ConstantOp>( 908 op->getLoc(), rewriter.getI64Type(), 909 rewriter.getI64IntegerAttr(op.count())); 910 911 auto operand = adaptor.operand(); 912 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 913 ValueRange({operand, count})); 914 915 return success(); 916 } 917 918 private: 919 StringRef apiFunctionName; 920 }; 921 922 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 923 public: 924 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 925 : RefCountingOpLowering(converter, ctx, kAddRef) {} 926 }; 927 928 class RuntimeDropRefOpLowering 929 : public RefCountingOpLowering<RuntimeDropRefOp> { 930 public: 931 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 932 : RefCountingOpLowering(converter, ctx, kDropRef) {} 933 }; 934 } // namespace 935 936 //===----------------------------------------------------------------------===// 937 // Convert return operations that return async values from async regions. 
938 //===----------------------------------------------------------------------===// 939 940 namespace { 941 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 942 public: 943 using OpConversionPattern::OpConversionPattern; 944 945 LogicalResult 946 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 947 ConversionPatternRewriter &rewriter) const override { 948 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 949 return success(); 950 } 951 }; 952 } // namespace 953 954 //===----------------------------------------------------------------------===// 955 956 namespace { 957 struct ConvertAsyncToLLVMPass 958 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 959 void runOnOperation() override; 960 }; 961 } // namespace 962 963 void ConvertAsyncToLLVMPass::runOnOperation() { 964 ModuleOp module = getOperation(); 965 MLIRContext *ctx = module->getContext(); 966 967 // Add declarations for most functions required by the coroutines lowering. 968 // We delay adding the resume function until it's needed because it currently 969 // fails to compile unless '-O0' is specified. 970 addAsyncRuntimeApiDeclarations(module); 971 addCRuntimeDeclarations(module); 972 973 // Lower async.runtime and async.coro operations to Async Runtime API and 974 // LLVM coroutine intrinsics. 975 976 // Convert async dialect types and operations to LLVM dialect. 977 AsyncRuntimeTypeConverter converter; 978 RewritePatternSet patterns(ctx); 979 980 // We use conversion to LLVM type to lower async.runtime load and store 981 // operations. 982 LLVMTypeConverter llvmConverter(ctx); 983 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 984 985 // Convert async types in function signatures and function calls. 986 populateFuncOpTypeConversionPattern(patterns, converter); 987 populateCallOpTypeConversionPattern(patterns, converter); 988 989 // Convert return operations inside async.execute regions. 
990 patterns.add<ReturnOpOpConversion>(converter, ctx); 991 992 // Lower async.runtime operations to the async runtime API calls. 993 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 994 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 995 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 996 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 997 RuntimeDropRefOpLowering>(converter, ctx); 998 999 // Lower async.runtime operations that rely on LLVM type converter to convert 1000 // from async value payload type to the LLVM type. 1001 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 1002 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 1003 ctx); 1004 1005 // Lower async coroutine operations to LLVM coroutine intrinsics. 1006 patterns 1007 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 1008 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1009 converter, ctx); 1010 1011 ConversionTarget target(*ctx); 1012 target 1013 .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>(); 1014 target.addLegalDialect<LLVM::LLVMDialect>(); 1015 1016 // All operations from Async dialect must be lowered to the runtime API and 1017 // LLVM intrinsics calls. 1018 target.addIllegalDialect<AsyncDialect>(); 1019 1020 // Add dynamic legality constraints to apply conversions defined above. 
1021 target.addDynamicallyLegalOp<FuncOp>( 1022 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1023 target.addDynamicallyLegalOp<ReturnOp>( 1024 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1025 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1026 return converter.isSignatureLegal(op.getCalleeType()); 1027 }); 1028 1029 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1030 signalPassFailure(); 1031 } 1032 1033 //===----------------------------------------------------------------------===// 1034 // Patterns for structural type conversions for the Async dialect operations. 1035 //===----------------------------------------------------------------------===// 1036 1037 namespace { 1038 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1039 public: 1040 using OpConversionPattern::OpConversionPattern; 1041 LogicalResult 1042 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1043 ConversionPatternRewriter &rewriter) const override { 1044 ExecuteOp newOp = 1045 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1046 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1047 newOp.getRegion().end()); 1048 1049 // Set operands and update block argument and result types. 1050 newOp->setOperands(adaptor.getOperands()); 1051 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1052 return failure(); 1053 for (auto result : newOp.getResults()) 1054 result.setType(typeConverter->convertType(result.getType())); 1055 1056 rewriter.replaceOp(op, newOp.getResults()); 1057 return success(); 1058 } 1059 }; 1060 1061 // Dummy pattern to trigger the appropriate type conversion / materialization. 
1062 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1063 public: 1064 using OpConversionPattern::OpConversionPattern; 1065 LogicalResult 1066 matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1067 ConversionPatternRewriter &rewriter) const override { 1068 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1069 return success(); 1070 } 1071 }; 1072 1073 // Dummy pattern to trigger the appropriate type conversion / materialization. 1074 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1075 public: 1076 using OpConversionPattern::OpConversionPattern; 1077 LogicalResult 1078 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1079 ConversionPatternRewriter &rewriter) const override { 1080 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1081 return success(); 1082 } 1083 }; 1084 } // namespace 1085 1086 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1087 return std::make_unique<ConvertAsyncToLLVMPass>(); 1088 } 1089 1090 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1091 TypeConverter &typeConverter, RewritePatternSet &patterns, 1092 ConversionTarget &target) { 1093 typeConverter.addConversion([&](TokenType type) { return type; }); 1094 typeConverter.addConversion([&](ValueType type) { 1095 Type converted = typeConverter.convertType(type.getValueType()); 1096 return converted ? ValueType::get(converted) : converted; 1097 }); 1098 1099 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1100 typeConverter, patterns.getContext()); 1101 1102 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1103 [&](Operation *op) { return typeConverter.isLegal(op); }); 1104 } 1105