1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Dialect/Async/IR/Async.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 25 #define DEBUG_TYPE "convert-async-to-llvm" 26 27 using namespace mlir; 28 using namespace mlir::async; 29 30 //===----------------------------------------------------------------------===// 31 // Async Runtime C API declaration. 32 //===----------------------------------------------------------------------===// 33 34 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 35 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 36 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 37 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 38 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 39 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 40 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 41 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 42 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 43 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 44 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 45 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 46 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 47 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 48 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 49 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 50 static constexpr const char *kGetValueStorage = 51 "mlirAsyncRuntimeGetValueStorage"; 52 static constexpr const char *kAddTokenToGroup = 53 "mlirAsyncRuntimeAddTokenToGroup"; 54 static constexpr const char *kAwaitTokenAndExecute = 55 "mlirAsyncRuntimeAwaitTokenAndExecute"; 56 static constexpr const char *kAwaitValueAndExecute = 57 "mlirAsyncRuntimeAwaitValueAndExecute"; 58 static constexpr const char *kAwaitAllAndExecute = 59 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 60 61 namespace { 62 /// Async Runtime API function types. 63 /// 64 /// Because we can't create API function signature for type parametrized 65 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 66 /// lowering all async data types become opaque pointers at runtime. 67 struct AsyncAPI { 68 // All async types are lowered to opaque i8* LLVM pointers at runtime. 69 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 70 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 71 } 72 73 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 74 return LLVM::LLVMTokenType::get(ctx); 75 } 76 77 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 78 auto ref = opaquePointerType(ctx); 79 auto count = IntegerType::get(ctx, 64); 80 return FunctionType::get(ctx, {ref, count}, {}); 81 } 82 83 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 84 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 85 } 86 87 static FunctionType createValueFunctionType(MLIRContext *ctx) { 88 auto i64 = IntegerType::get(ctx, 64); 89 auto value = opaquePointerType(ctx); 90 return FunctionType::get(ctx, {i64}, {value}); 91 } 92 93 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 94 auto i64 = IntegerType::get(ctx, 64); 95 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 96 } 97 98 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 99 auto value = opaquePointerType(ctx); 100 auto storage = opaquePointerType(ctx); 101 return FunctionType::get(ctx, {value}, {storage}); 102 } 103 104 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 105 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 106 } 107 108 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 109 auto value = opaquePointerType(ctx); 110 return FunctionType::get(ctx, {value}, {}); 111 } 112 113 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 114 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 115 } 116 117 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 118 auto value = opaquePointerType(ctx); 119 return FunctionType::get(ctx, {value}, {}); 120 } 121 122 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 123 auto i1 = IntegerType::get(ctx, 1); 124 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 125 } 126 127 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 128 auto value = opaquePointerType(ctx); 129 auto i1 = IntegerType::get(ctx, 1); 130 return FunctionType::get(ctx, {value}, {i1}); 131 } 132 133 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 134 auto i1 = IntegerType::get(ctx, 1); 135 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 136 } 137 138 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 139 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 140 } 141 142 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 143 auto value = opaquePointerType(ctx); 144 return FunctionType::get(ctx, {value}, {}); 145 } 146 147 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 148 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 149 } 150 151 static FunctionType executeFunctionType(MLIRContext *ctx) { 152 auto hdl = opaquePointerType(ctx); 153 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 154 return FunctionType::get(ctx, {hdl, resume}, {}); 155 } 156 157 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 158 auto i64 = IntegerType::get(ctx, 64); 159 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 160 {i64}); 161 } 162 163 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 164 auto hdl = opaquePointerType(ctx); 165 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 166 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 167 } 168 169 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 170 auto value = opaquePointerType(ctx); 171 auto hdl = opaquePointerType(ctx); 172 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 173 return FunctionType::get(ctx, {value, hdl, resume}, {}); 174 } 175 176 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 177 auto hdl = opaquePointerType(ctx); 178 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 179 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 180 } 181 182 // Auxiliary coroutine resume intrinsic wrapper. 183 static Type resumeFunctionType(MLIRContext *ctx) { 184 auto voidTy = LLVM::LLVMVoidType::get(ctx); 185 auto i8Ptr = opaquePointerType(ctx); 186 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 187 } 188 }; 189 } // namespace 190 191 /// Adds Async Runtime C API declarations to the module. 192 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 193 auto builder = 194 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 195 196 auto addFuncDecl = [&](StringRef name, FunctionType type) { 197 if (module.lookupSymbol(name)) 198 return; 199 builder.create<FuncOp>(name, type).setPrivate(); 200 }; 201 202 MLIRContext *ctx = module.getContext(); 203 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 204 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 205 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 206 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 207 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 208 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 209 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 210 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 211 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 212 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 213 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 214 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 215 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 216 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 217 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 218 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 219 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 220 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 221 addFuncDecl(kAwaitTokenAndExecute, 222 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 223 addFuncDecl(kAwaitValueAndExecute, 224 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 225 addFuncDecl(kAwaitAllAndExecute, 226 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 227 } 228 229 //===----------------------------------------------------------------------===// 230 // Add malloc/free declarations to the module. 231 //===----------------------------------------------------------------------===// 232 233 static constexpr const char *kMalloc = "malloc"; 234 static constexpr const char *kFree = "free"; 235 236 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 237 StringRef name, Type ret, ArrayRef<Type> params) { 238 if (module.lookupSymbol(name)) 239 return; 240 Type type = LLVM::LLVMFunctionType::get(ret, params); 241 builder.create<LLVM::LLVMFuncOp>(name, type); 242 } 243 244 /// Adds malloc/free declarations to the module. 245 static void addCRuntimeDeclarations(ModuleOp module) { 246 using namespace mlir::LLVM; 247 248 MLIRContext *ctx = module.getContext(); 249 auto builder = 250 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 251 252 auto voidTy = LLVMVoidType::get(ctx); 253 auto i64 = IntegerType::get(ctx, 64); 254 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 255 256 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 257 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // Coroutine resume function wrapper. 262 //===----------------------------------------------------------------------===// 263 264 static constexpr const char *kResume = "__resume"; 265 266 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 267 /// intrinsics. We need this function to be able to pass it to the async 268 /// runtime execute API. 269 static void addResumeFunction(ModuleOp module) { 270 if (module.lookupSymbol(kResume)) 271 return; 272 273 MLIRContext *ctx = module.getContext(); 274 auto loc = module.getLoc(); 275 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 276 277 auto voidTy = LLVM::LLVMVoidType::get(ctx); 278 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 279 280 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 281 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 282 resumeOp.setPrivate(); 283 284 auto *block = resumeOp.addEntryBlock(); 285 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 286 287 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 288 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 289 } 290 291 //===----------------------------------------------------------------------===// 292 // Convert Async dialect types to LLVM types. 293 //===----------------------------------------------------------------------===// 294 295 namespace { 296 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 297 /// their runtime type (opaque pointers) and does not convert any other types. 298 class AsyncRuntimeTypeConverter : public TypeConverter { 299 public: 300 AsyncRuntimeTypeConverter() { 301 addConversion([](Type type) { return type; }); 302 addConversion(convertAsyncTypes); 303 } 304 305 static Optional<Type> convertAsyncTypes(Type type) { 306 if (type.isa<TokenType, GroupType, ValueType>()) 307 return AsyncAPI::opaquePointerType(type.getContext()); 308 309 if (type.isa<CoroIdType, CoroStateType>()) 310 return AsyncAPI::tokenType(type.getContext()); 311 if (type.isa<CoroHandleType>()) 312 return AsyncAPI::opaquePointerType(type.getContext()); 313 314 return llvm::None; 315 } 316 }; 317 } // namespace 318 319 //===----------------------------------------------------------------------===// 320 // Convert async.coro.id to @llvm.coro.id intrinsic. 321 //===----------------------------------------------------------------------===// 322 323 namespace { 324 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 325 public: 326 using OpConversionPattern::OpConversionPattern; 327 328 LogicalResult 329 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 330 ConversionPatternRewriter &rewriter) const override { 331 auto token = AsyncAPI::tokenType(op->getContext()); 332 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 333 auto loc = op->getLoc(); 334 335 // Constants for initializing coroutine frame. 336 auto constZero = rewriter.create<LLVM::ConstantOp>( 337 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 338 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 339 340 // Get coroutine id: @llvm.coro.id. 341 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 342 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 343 344 return success(); 345 } 346 }; 347 } // namespace 348 349 //===----------------------------------------------------------------------===// 350 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 351 //===----------------------------------------------------------------------===// 352 353 namespace { 354 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 355 public: 356 using OpConversionPattern::OpConversionPattern; 357 358 LogicalResult 359 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 360 ConversionPatternRewriter &rewriter) const override { 361 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 362 auto loc = op->getLoc(); 363 364 // Get coroutine frame size: @llvm.coro.size.i64. 365 auto coroSize = 366 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 367 368 // Allocate memory for the coroutine frame. 369 auto coroAlloc = rewriter.create<LLVM::CallOp>( 370 loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc), 371 ValueRange(coroSize.getResult())); 372 373 // Begin a coroutine: @llvm.coro.begin. 374 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); 375 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 376 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 377 378 return success(); 379 } 380 }; 381 } // namespace 382 383 //===----------------------------------------------------------------------===// 384 // Convert async.coro.free to @llvm.coro.free intrinsic. 385 //===----------------------------------------------------------------------===// 386 387 namespace { 388 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 389 public: 390 using OpConversionPattern::OpConversionPattern; 391 392 LogicalResult 393 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 394 ConversionPatternRewriter &rewriter) const override { 395 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 396 auto loc = op->getLoc(); 397 398 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 399 auto coroMem = 400 rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands()); 401 402 // Free the memory. 403 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 404 op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree), 405 ValueRange(coroMem.getResult())); 406 407 return success(); 408 } 409 }; 410 } // namespace 411 412 //===----------------------------------------------------------------------===// 413 // Convert async.coro.end to @llvm.coro.end intrinsic. 414 //===----------------------------------------------------------------------===// 415 416 namespace { 417 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 418 public: 419 using OpConversionPattern::OpConversionPattern; 420 421 LogicalResult 422 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 423 ConversionPatternRewriter &rewriter) const override { 424 // We are not in the block that is part of the unwind sequence. 425 auto constFalse = rewriter.create<LLVM::ConstantOp>( 426 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 427 428 // Mark the end of a coroutine: @llvm.coro.end. 429 auto coroHdl = adaptor.handle(); 430 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 431 ValueRange({coroHdl, constFalse})); 432 rewriter.eraseOp(op); 433 434 return success(); 435 } 436 }; 437 } // namespace 438 439 //===----------------------------------------------------------------------===// 440 // Convert async.coro.save to @llvm.coro.save intrinsic. 441 //===----------------------------------------------------------------------===// 442 443 namespace { 444 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 445 public: 446 using OpConversionPattern::OpConversionPattern; 447 448 LogicalResult 449 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 450 ConversionPatternRewriter &rewriter) const override { 451 // Save the coroutine state: @llvm.coro.save 452 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 453 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 454 455 return success(); 456 } 457 }; 458 } // namespace 459 460 //===----------------------------------------------------------------------===// 461 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 462 //===----------------------------------------------------------------------===// 463 464 namespace { 465 466 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 467 /// branch to the appropriate block based on the return code. 468 /// 469 /// Before: 470 /// 471 /// ^suspended: 472 /// "opBefore"(...) 473 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 474 /// ^resume: 475 /// "op"(...) 476 /// ^cleanup: ... 477 /// ^suspend: ... 478 /// 479 /// After: 480 /// 481 /// ^suspended: 482 /// "opBefore"(...) 483 /// %suspend = llmv.intr.coro.suspend ... 484 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 485 /// ^resume: 486 /// "op"(...) 487 /// ^cleanup: ... 488 /// ^suspend: ... 489 /// 490 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 491 public: 492 using OpConversionPattern::OpConversionPattern; 493 494 LogicalResult 495 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 496 ConversionPatternRewriter &rewriter) const override { 497 auto i8 = rewriter.getIntegerType(8); 498 auto i32 = rewriter.getI32Type(); 499 auto loc = op->getLoc(); 500 501 // This is not a final suspension point. 502 auto constFalse = rewriter.create<LLVM::ConstantOp>( 503 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 504 505 // Suspend a coroutine: @llvm.coro.suspend 506 auto coroState = adaptor.state(); 507 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 508 loc, i8, ValueRange({coroState, constFalse})); 509 510 // Cast return code to i32. 511 512 // After a suspension point decide if we should branch into resume, cleanup 513 // or suspend block of the coroutine (see @llvm.coro.suspend return code 514 // documentation). 515 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 516 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 517 op.cleanupDest()}; 518 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 519 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 520 /*defaultDestination=*/op.suspendDest(), 521 /*defaultOperands=*/ValueRange(), 522 /*caseValues=*/caseValues, 523 /*caseDestinations=*/caseDest, 524 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 525 /*branchWeights=*/ArrayRef<int32_t>()); 526 527 return success(); 528 } 529 }; 530 } // namespace 531 532 //===----------------------------------------------------------------------===// 533 // Convert async.runtime.create to the corresponding runtime API call. 534 // 535 // To allocate storage for the async values we use getelementptr trick: 536 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 537 //===----------------------------------------------------------------------===// 538 539 namespace { 540 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 541 public: 542 using OpConversionPattern::OpConversionPattern; 543 544 LogicalResult 545 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 546 ConversionPatternRewriter &rewriter) const override { 547 TypeConverter *converter = getTypeConverter(); 548 Type resultType = op->getResultTypes()[0]; 549 550 // Tokens creation maps to a simple function call. 551 if (resultType.isa<TokenType>()) { 552 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken, 553 converter->convertType(resultType)); 554 return success(); 555 } 556 557 // To create a value we need to compute the storage requirement. 558 if (auto value = resultType.dyn_cast<ValueType>()) { 559 // Returns the size requirements for the async value storage. 560 auto sizeOf = [&](ValueType valueType) -> Value { 561 auto loc = op->getLoc(); 562 auto i64 = rewriter.getI64Type(); 563 564 auto storedType = converter->convertType(valueType.getValueType()); 565 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 566 567 // %Size = getelementptr %T* null, int 1 568 // %SizeI = ptrtoint %T* %Size to i64 569 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 570 auto one = rewriter.create<LLVM::ConstantOp>( 571 loc, i64, rewriter.getI64IntegerAttr(1)); 572 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 573 one.getResult()); 574 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 575 }; 576 577 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 578 sizeOf(value)); 579 580 return success(); 581 } 582 583 return rewriter.notifyMatchFailure(op, "unsupported async type"); 584 } 585 }; 586 } // namespace 587 588 //===----------------------------------------------------------------------===// 589 // Convert async.runtime.create_group to the corresponding runtime API call. 590 //===----------------------------------------------------------------------===// 591 592 namespace { 593 class RuntimeCreateGroupOpLowering 594 : public OpConversionPattern<RuntimeCreateGroupOp> { 595 public: 596 using OpConversionPattern::OpConversionPattern; 597 598 LogicalResult 599 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 600 ConversionPatternRewriter &rewriter) const override { 601 TypeConverter *converter = getTypeConverter(); 602 Type resultType = op.getResult().getType(); 603 604 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, 605 converter->convertType(resultType), 606 adaptor.getOperands()); 607 return success(); 608 } 609 }; 610 } // namespace 611 612 //===----------------------------------------------------------------------===// 613 // Convert async.runtime.set_available to the corresponding runtime API call. 614 //===----------------------------------------------------------------------===// 615 616 namespace { 617 class RuntimeSetAvailableOpLowering 618 : public OpConversionPattern<RuntimeSetAvailableOp> { 619 public: 620 using OpConversionPattern::OpConversionPattern; 621 622 LogicalResult 623 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 624 ConversionPatternRewriter &rewriter) const override { 625 StringRef apiFuncName = 626 TypeSwitch<Type, StringRef>(op.operand().getType()) 627 .Case<TokenType>([](Type) { return kEmplaceToken; }) 628 .Case<ValueType>([](Type) { return kEmplaceValue; }); 629 630 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 631 adaptor.getOperands()); 632 633 return success(); 634 } 635 }; 636 } // namespace 637 638 //===----------------------------------------------------------------------===// 639 // Convert async.runtime.set_error to the corresponding runtime API call. 640 //===----------------------------------------------------------------------===// 641 642 namespace { 643 class RuntimeSetErrorOpLowering 644 : public OpConversionPattern<RuntimeSetErrorOp> { 645 public: 646 using OpConversionPattern::OpConversionPattern; 647 648 LogicalResult 649 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 650 ConversionPatternRewriter &rewriter) const override { 651 StringRef apiFuncName = 652 TypeSwitch<Type, StringRef>(op.operand().getType()) 653 .Case<TokenType>([](Type) { return kSetTokenError; }) 654 .Case<ValueType>([](Type) { return kSetValueError; }); 655 656 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 657 adaptor.getOperands()); 658 659 return success(); 660 } 661 }; 662 } // namespace 663 664 //===----------------------------------------------------------------------===// 665 // Convert async.runtime.is_error to the corresponding runtime API call. 666 //===----------------------------------------------------------------------===// 667 668 namespace { 669 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 670 public: 671 using OpConversionPattern::OpConversionPattern; 672 673 LogicalResult 674 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 675 ConversionPatternRewriter &rewriter) const override { 676 StringRef apiFuncName = 677 TypeSwitch<Type, StringRef>(op.operand().getType()) 678 .Case<TokenType>([](Type) { return kIsTokenError; }) 679 .Case<GroupType>([](Type) { return kIsGroupError; }) 680 .Case<ValueType>([](Type) { return kIsValueError; }); 681 682 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 683 adaptor.getOperands()); 684 return success(); 685 } 686 }; 687 } // namespace 688 689 //===----------------------------------------------------------------------===// 690 // Convert async.runtime.await to the corresponding runtime API call. 691 //===----------------------------------------------------------------------===// 692 693 namespace { 694 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 695 public: 696 using OpConversionPattern::OpConversionPattern; 697 698 LogicalResult 699 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 700 ConversionPatternRewriter &rewriter) const override { 701 StringRef apiFuncName = 702 TypeSwitch<Type, StringRef>(op.operand().getType()) 703 .Case<TokenType>([](Type) { return kAwaitToken; }) 704 .Case<ValueType>([](Type) { return kAwaitValue; }) 705 .Case<GroupType>([](Type) { return kAwaitGroup; }); 706 707 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 708 adaptor.getOperands()); 709 rewriter.eraseOp(op); 710 711 return success(); 712 } 713 }; 714 } // namespace 715 716 //===----------------------------------------------------------------------===// 717 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 718 //===----------------------------------------------------------------------===// 719 720 namespace { 721 class RuntimeAwaitAndResumeOpLowering 722 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 723 public: 724 using OpConversionPattern::OpConversionPattern; 725 726 LogicalResult 727 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 728 ConversionPatternRewriter &rewriter) const override { 729 StringRef apiFuncName = 730 TypeSwitch<Type, StringRef>(op.operand().getType()) 731 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 732 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 733 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 734 735 Value operand = adaptor.operand(); 736 Value handle = adaptor.handle(); 737 738 // A pointer to coroutine resume intrinsic wrapper. 739 addResumeFunction(op->getParentOfType<ModuleOp>()); 740 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 741 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 742 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 743 744 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 745 ValueRange({operand, handle, resumePtr.res()})); 746 rewriter.eraseOp(op); 747 748 return success(); 749 } 750 }; 751 } // namespace 752 753 //===----------------------------------------------------------------------===// 754 // Convert async.runtime.resume to the corresponding runtime API call. 755 //===----------------------------------------------------------------------===// 756 757 namespace { 758 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 759 public: 760 using OpConversionPattern::OpConversionPattern; 761 762 LogicalResult 763 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 764 ConversionPatternRewriter &rewriter) const override { 765 // A pointer to coroutine resume intrinsic wrapper. 766 addResumeFunction(op->getParentOfType<ModuleOp>()); 767 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 768 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 769 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 770 771 // Call async runtime API to execute a coroutine in the managed thread. 772 auto coroHdl = adaptor.handle(); 773 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute, 774 ValueRange({coroHdl, resumePtr.res()})); 775 776 return success(); 777 } 778 }; 779 } // namespace 780 781 //===----------------------------------------------------------------------===// 782 // Convert async.runtime.store to the corresponding runtime API call. 783 //===----------------------------------------------------------------------===// 784 785 namespace { 786 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 787 public: 788 using OpConversionPattern::OpConversionPattern; 789 790 LogicalResult 791 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 792 ConversionPatternRewriter &rewriter) const override { 793 Location loc = op->getLoc(); 794 795 // Get a pointer to the async value storage from the runtime. 796 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 797 auto storage = adaptor.storage(); 798 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 799 TypeRange(i8Ptr), storage); 800 801 // Cast from i8* to the LLVM pointer type. 802 auto valueType = op.value().getType(); 803 auto llvmValueType = getTypeConverter()->convertType(valueType); 804 if (!llvmValueType) 805 return rewriter.notifyMatchFailure( 806 op, "failed to convert stored value type to LLVM type"); 807 808 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 809 loc, LLVM::LLVMPointerType::get(llvmValueType), 810 storagePtr.getResult(0)); 811 812 // Store the yielded value into the async value storage. 813 auto value = adaptor.value(); 814 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 815 816 // Erase the original runtime store operation. 817 rewriter.eraseOp(op); 818 819 return success(); 820 } 821 }; 822 } // namespace 823 824 //===----------------------------------------------------------------------===// 825 // Convert async.runtime.load to the corresponding runtime API call. 826 //===----------------------------------------------------------------------===// 827 828 namespace { 829 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 830 public: 831 using OpConversionPattern::OpConversionPattern; 832 833 LogicalResult 834 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 835 ConversionPatternRewriter &rewriter) const override { 836 Location loc = op->getLoc(); 837 838 // Get a pointer to the async value storage from the runtime. 839 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 840 auto storage = adaptor.storage(); 841 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 842 TypeRange(i8Ptr), storage); 843 844 // Cast from i8* to the LLVM pointer type. 845 auto valueType = op.result().getType(); 846 auto llvmValueType = getTypeConverter()->convertType(valueType); 847 if (!llvmValueType) 848 return rewriter.notifyMatchFailure( 849 op, "failed to convert loaded value type to LLVM type"); 850 851 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 852 loc, LLVM::LLVMPointerType::get(llvmValueType), 853 storagePtr.getResult(0)); 854 855 // Load from the casted pointer. 856 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 857 858 return success(); 859 } 860 }; 861 } // namespace 862 863 //===----------------------------------------------------------------------===// 864 // Convert async.runtime.add_to_group to the corresponding runtime API call. 865 //===----------------------------------------------------------------------===// 866 867 namespace { 868 class RuntimeAddToGroupOpLowering 869 : public OpConversionPattern<RuntimeAddToGroupOp> { 870 public: 871 using OpConversionPattern::OpConversionPattern; 872 873 LogicalResult 874 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 875 ConversionPatternRewriter &rewriter) const override { 876 // Currently we can only add tokens to the group. 877 if (!op.operand().getType().isa<TokenType>()) 878 return rewriter.notifyMatchFailure(op, "only token type is supported"); 879 880 // Replace with a runtime API function call. 881 rewriter.replaceOpWithNewOp<CallOp>( 882 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 883 884 return success(); 885 } 886 }; 887 } // namespace 888 889 //===----------------------------------------------------------------------===// 890 // Async reference counting ops lowering (`async.runtime.add_ref` and 891 // `async.runtime.drop_ref` to the corresponding API calls). 892 //===----------------------------------------------------------------------===// 893 894 namespace { 895 template <typename RefCountingOp> 896 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 897 public: 898 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 899 StringRef apiFunctionName) 900 : OpConversionPattern<RefCountingOp>(converter, ctx), 901 apiFunctionName(apiFunctionName) {} 902 903 LogicalResult 904 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 905 ConversionPatternRewriter &rewriter) const override { 906 auto count = 907 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI64Type(), 908 rewriter.getI64IntegerAttr(op.count())); 909 910 auto operand = adaptor.operand(); 911 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 912 ValueRange({operand, count})); 913 914 return success(); 915 } 916 917 private: 918 StringRef apiFunctionName; 919 }; 920 921 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 922 public: 923 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 924 : RefCountingOpLowering(converter, ctx, kAddRef) {} 925 }; 926 927 class RuntimeDropRefOpLowering 928 : public RefCountingOpLowering<RuntimeDropRefOp> { 929 public: 930 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 931 : RefCountingOpLowering(converter, ctx, kDropRef) {} 932 }; 933 } // namespace 934 935 //===----------------------------------------------------------------------===// 936 // Convert return operations that return async values from async regions. 937 //===----------------------------------------------------------------------===// 938 939 namespace { 940 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 941 public: 942 using OpConversionPattern::OpConversionPattern; 943 944 LogicalResult 945 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 946 ConversionPatternRewriter &rewriter) const override { 947 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 948 return success(); 949 } 950 }; 951 } // namespace 952 953 //===----------------------------------------------------------------------===// 954 955 namespace { 956 struct ConvertAsyncToLLVMPass 957 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 958 void runOnOperation() override; 959 }; 960 } // namespace 961 962 void ConvertAsyncToLLVMPass::runOnOperation() { 963 ModuleOp module = getOperation(); 964 MLIRContext *ctx = module->getContext(); 965 966 // Add declarations for most functions required by the coroutines lowering. 967 // We delay adding the resume function until it's needed because it currently 968 // fails to compile unless '-O0' is specified. 969 addAsyncRuntimeApiDeclarations(module); 970 addCRuntimeDeclarations(module); 971 972 // Lower async.runtime and async.coro operations to Async Runtime API and 973 // LLVM coroutine intrinsics. 974 975 // Convert async dialect types and operations to LLVM dialect. 976 AsyncRuntimeTypeConverter converter; 977 RewritePatternSet patterns(ctx); 978 979 // We use conversion to LLVM type to lower async.runtime load and store 980 // operations. 981 LLVMTypeConverter llvmConverter(ctx); 982 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 983 984 // Convert async types in function signatures and function calls. 985 populateFuncOpTypeConversionPattern(patterns, converter); 986 populateCallOpTypeConversionPattern(patterns, converter); 987 988 // Convert return operations inside async.execute regions. 989 patterns.add<ReturnOpOpConversion>(converter, ctx); 990 991 // Lower async.runtime operations to the async runtime API calls. 992 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 993 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 994 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 995 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 996 RuntimeDropRefOpLowering>(converter, ctx); 997 998 // Lower async.runtime operations that rely on LLVM type converter to convert 999 // from async value payload type to the LLVM type. 1000 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 1001 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 1002 ctx); 1003 1004 // Lower async coroutine operations to LLVM coroutine intrinsics. 1005 patterns 1006 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 1007 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1008 converter, ctx); 1009 1010 ConversionTarget target(*ctx); 1011 target.addLegalOp<ConstantOp, UnrealizedConversionCastOp>(); 1012 target.addLegalDialect<LLVM::LLVMDialect>(); 1013 1014 // All operations from Async dialect must be lowered to the runtime API and 1015 // LLVM intrinsics calls. 1016 target.addIllegalDialect<AsyncDialect>(); 1017 1018 // Add dynamic legality constraints to apply conversions defined above. 1019 target.addDynamicallyLegalOp<FuncOp>( 1020 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1021 target.addDynamicallyLegalOp<ReturnOp>( 1022 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1023 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1024 return converter.isSignatureLegal(op.getCalleeType()); 1025 }); 1026 1027 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1028 signalPassFailure(); 1029 } 1030 1031 //===----------------------------------------------------------------------===// 1032 // Patterns for structural type conversions for the Async dialect operations. 1033 //===----------------------------------------------------------------------===// 1034 1035 namespace { 1036 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1037 public: 1038 using OpConversionPattern::OpConversionPattern; 1039 LogicalResult 1040 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1041 ConversionPatternRewriter &rewriter) const override { 1042 ExecuteOp newOp = 1043 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1044 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1045 newOp.getRegion().end()); 1046 1047 // Set operands and update block argument and result types. 1048 newOp->setOperands(adaptor.getOperands()); 1049 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1050 return failure(); 1051 for (auto result : newOp.getResults()) 1052 result.setType(typeConverter->convertType(result.getType())); 1053 1054 rewriter.replaceOp(op, newOp.getResults()); 1055 return success(); 1056 } 1057 }; 1058 1059 // Dummy pattern to trigger the appropriate type conversion / materialization. 1060 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1061 public: 1062 using OpConversionPattern::OpConversionPattern; 1063 LogicalResult 1064 matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1065 ConversionPatternRewriter &rewriter) const override { 1066 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1067 return success(); 1068 } 1069 }; 1070 1071 // Dummy pattern to trigger the appropriate type conversion / materialization. 1072 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1073 public: 1074 using OpConversionPattern::OpConversionPattern; 1075 LogicalResult 1076 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1077 ConversionPatternRewriter &rewriter) const override { 1078 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1079 return success(); 1080 } 1081 }; 1082 } // namespace 1083 1084 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1085 return std::make_unique<ConvertAsyncToLLVMPass>(); 1086 } 1087 1088 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1089 TypeConverter &typeConverter, RewritePatternSet &patterns, 1090 ConversionTarget &target) { 1091 typeConverter.addConversion([&](TokenType type) { return type; }); 1092 typeConverter.addConversion([&](ValueType type) { 1093 Type converted = typeConverter.convertType(type.getValueType()); 1094 return converted ? ValueType::get(converted) : converted; 1095 }); 1096 1097 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1098 typeConverter, patterns.getContext()); 1099 1100 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1101 [&](Operation *op) { return typeConverter.isLegal(op); }); 1102 } 1103