1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 #define DEBUG_TYPE "convert-async-to-llvm" 23 24 using namespace mlir; 25 using namespace mlir::async; 26 27 //===----------------------------------------------------------------------===// 28 // Async Runtime C API declaration. 29 //===----------------------------------------------------------------------===// 30 31 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 32 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 33 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 34 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 35 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 36 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 37 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 38 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 39 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 40 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 41 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 42 static constexpr const char *kGetValueStorage = 43 "mlirAsyncRuntimeGetValueStorage"; 44 static constexpr const char *kAddTokenToGroup = 45 "mlirAsyncRuntimeAddTokenToGroup"; 46 static constexpr const char *kAwaitTokenAndExecute = 47 "mlirAsyncRuntimeAwaitTokenAndExecute"; 48 static constexpr const char *kAwaitValueAndExecute = 49 "mlirAsyncRuntimeAwaitValueAndExecute"; 50 static constexpr const char *kAwaitAllAndExecute = 51 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 52 53 namespace { 54 /// Async Runtime API function types. 55 /// 56 /// Because we can't create API function signature for type parametrized 57 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 58 /// lowering all async data types become opaque pointers at runtime. 59 struct AsyncAPI { 60 // All async types are lowered to opaque i8* LLVM pointers at runtime. 61 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 62 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 63 } 64 65 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 66 return LLVM::LLVMTokenType::get(ctx); 67 } 68 69 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 70 auto ref = opaquePointerType(ctx); 71 auto count = IntegerType::get(ctx, 32); 72 return FunctionType::get(ctx, {ref, count}, {}); 73 } 74 75 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 76 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 77 } 78 79 static FunctionType createValueFunctionType(MLIRContext *ctx) { 80 auto i32 = IntegerType::get(ctx, 32); 81 auto value = opaquePointerType(ctx); 82 return FunctionType::get(ctx, {i32}, {value}); 83 } 84 85 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 86 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 87 } 88 89 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 90 auto value = opaquePointerType(ctx); 91 auto storage = opaquePointerType(ctx); 92 return FunctionType::get(ctx, {value}, {storage}); 93 } 94 95 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 96 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 97 } 98 99 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 100 auto value = opaquePointerType(ctx); 101 return FunctionType::get(ctx, {value}, {}); 102 } 103 104 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 105 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 106 } 107 108 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 109 auto value = opaquePointerType(ctx); 110 return FunctionType::get(ctx, {value}, {}); 111 } 112 113 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 114 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 115 } 116 117 static FunctionType executeFunctionType(MLIRContext *ctx) { 118 auto hdl = opaquePointerType(ctx); 119 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 120 return FunctionType::get(ctx, {hdl, resume}, {}); 121 } 122 123 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 124 auto i64 = IntegerType::get(ctx, 64); 125 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 126 {i64}); 127 } 128 129 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 130 auto hdl = opaquePointerType(ctx); 131 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 132 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 133 } 134 135 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 136 auto value = opaquePointerType(ctx); 137 auto hdl = opaquePointerType(ctx); 138 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 139 return FunctionType::get(ctx, {value, hdl, resume}, {}); 140 } 141 142 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 143 auto hdl = opaquePointerType(ctx); 144 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 145 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 146 } 147 148 // Auxiliary coroutine resume intrinsic wrapper. 149 static Type resumeFunctionType(MLIRContext *ctx) { 150 auto voidTy = LLVM::LLVMVoidType::get(ctx); 151 auto i8Ptr = opaquePointerType(ctx); 152 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 153 } 154 }; 155 } // namespace 156 157 /// Adds Async Runtime C API declarations to the module. 158 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 159 auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), 160 module.getBody()); 161 162 auto addFuncDecl = [&](StringRef name, FunctionType type) { 163 if (module.lookupSymbol(name)) 164 return; 165 builder.create<FuncOp>(name, type).setPrivate(); 166 }; 167 168 MLIRContext *ctx = module.getContext(); 169 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 170 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 171 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 172 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 173 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 174 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 175 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 176 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 177 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 178 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 179 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 180 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 181 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 182 addFuncDecl(kAwaitTokenAndExecute, 183 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 184 addFuncDecl(kAwaitValueAndExecute, 185 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 186 addFuncDecl(kAwaitAllAndExecute, 187 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 188 } 189 190 //===----------------------------------------------------------------------===// 191 // Add malloc/free declarations to the module. 192 //===----------------------------------------------------------------------===// 193 194 static constexpr const char *kMalloc = "malloc"; 195 static constexpr const char *kFree = "free"; 196 197 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 198 StringRef name, Type ret, ArrayRef<Type> params) { 199 if (module.lookupSymbol(name)) 200 return; 201 Type type = LLVM::LLVMFunctionType::get(ret, params); 202 builder.create<LLVM::LLVMFuncOp>(name, type); 203 } 204 205 /// Adds malloc/free declarations to the module. 206 static void addCRuntimeDeclarations(ModuleOp module) { 207 using namespace mlir::LLVM; 208 209 MLIRContext *ctx = module.getContext(); 210 ImplicitLocOpBuilder builder(module.getLoc(), 211 module.getBody()->getTerminator()); 212 213 auto voidTy = LLVMVoidType::get(ctx); 214 auto i64 = IntegerType::get(ctx, 64); 215 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 216 217 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 218 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 219 } 220 221 //===----------------------------------------------------------------------===// 222 // Coroutine resume function wrapper. 223 //===----------------------------------------------------------------------===// 224 225 static constexpr const char *kResume = "__resume"; 226 227 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 228 /// intrinsics. We need this function to be able to pass it to the async 229 /// runtime execute API. 230 static void addResumeFunction(ModuleOp module) { 231 MLIRContext *ctx = module.getContext(); 232 233 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 234 Location loc = module.getLoc(); 235 236 if (module.lookupSymbol(kResume)) 237 return; 238 239 auto voidTy = LLVM::LLVMVoidType::get(ctx); 240 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 241 242 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 243 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 244 resumeOp.setPrivate(); 245 246 auto *block = resumeOp.addEntryBlock(); 247 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 248 249 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 250 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 251 } 252 253 //===----------------------------------------------------------------------===// 254 // Convert Async dialect types to LLVM types. 255 //===----------------------------------------------------------------------===// 256 257 namespace { 258 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 259 /// their runtime type (opaque pointers) and does not convert any other types. 260 class AsyncRuntimeTypeConverter : public TypeConverter { 261 public: 262 AsyncRuntimeTypeConverter() { 263 addConversion([](Type type) { return type; }); 264 addConversion(convertAsyncTypes); 265 } 266 267 static Optional<Type> convertAsyncTypes(Type type) { 268 if (type.isa<TokenType, GroupType, ValueType>()) 269 return AsyncAPI::opaquePointerType(type.getContext()); 270 271 if (type.isa<CoroIdType, CoroStateType>()) 272 return AsyncAPI::tokenType(type.getContext()); 273 if (type.isa<CoroHandleType>()) 274 return AsyncAPI::opaquePointerType(type.getContext()); 275 276 return llvm::None; 277 } 278 }; 279 } // namespace 280 281 //===----------------------------------------------------------------------===// 282 // Convert async.coro.id to @llvm.coro.id intrinsic. 283 //===----------------------------------------------------------------------===// 284 285 namespace { 286 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 287 public: 288 using OpConversionPattern::OpConversionPattern; 289 290 LogicalResult 291 matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands, 292 ConversionPatternRewriter &rewriter) const override { 293 auto token = AsyncAPI::tokenType(op->getContext()); 294 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 295 auto loc = op->getLoc(); 296 297 // Constants for initializing coroutine frame. 298 auto constZero = rewriter.create<LLVM::ConstantOp>( 299 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 300 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 301 302 // Get coroutine id: @llvm.coro.id. 303 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 304 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 305 306 return success(); 307 } 308 }; 309 } // namespace 310 311 //===----------------------------------------------------------------------===// 312 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 313 //===----------------------------------------------------------------------===// 314 315 namespace { 316 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 317 public: 318 using OpConversionPattern::OpConversionPattern; 319 320 LogicalResult 321 matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands, 322 ConversionPatternRewriter &rewriter) const override { 323 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 324 auto loc = op->getLoc(); 325 326 // Get coroutine frame size: @llvm.coro.size.i64. 327 auto coroSize = 328 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 329 330 // Allocate memory for the coroutine frame. 331 auto coroAlloc = rewriter.create<LLVM::CallOp>( 332 loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc), 333 ValueRange(coroSize.getResult())); 334 335 // Begin a coroutine: @llvm.coro.begin. 336 auto coroId = CoroBeginOpAdaptor(operands).id(); 337 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 338 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 339 340 return success(); 341 } 342 }; 343 } // namespace 344 345 //===----------------------------------------------------------------------===// 346 // Convert async.coro.free to @llvm.coro.free intrinsic. 347 //===----------------------------------------------------------------------===// 348 349 namespace { 350 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 351 public: 352 using OpConversionPattern::OpConversionPattern; 353 354 LogicalResult 355 matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands, 356 ConversionPatternRewriter &rewriter) const override { 357 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 358 auto loc = op->getLoc(); 359 360 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 361 auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands); 362 363 // Free the memory. 364 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 365 rewriter.getSymbolRefAttr(kFree), 366 ValueRange(coroMem.getResult())); 367 368 return success(); 369 } 370 }; 371 } // namespace 372 373 //===----------------------------------------------------------------------===// 374 // Convert async.coro.end to @llvm.coro.end intrinsic. 375 //===----------------------------------------------------------------------===// 376 377 namespace { 378 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 379 public: 380 using OpConversionPattern::OpConversionPattern; 381 382 LogicalResult 383 matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands, 384 ConversionPatternRewriter &rewriter) const override { 385 // We are not in the block that is part of the unwind sequence. 386 auto constFalse = rewriter.create<LLVM::ConstantOp>( 387 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 388 389 // Mark the end of a coroutine: @llvm.coro.end. 390 auto coroHdl = CoroEndOpAdaptor(operands).handle(); 391 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 392 ValueRange({coroHdl, constFalse})); 393 rewriter.eraseOp(op); 394 395 return success(); 396 } 397 }; 398 } // namespace 399 400 //===----------------------------------------------------------------------===// 401 // Convert async.coro.save to @llvm.coro.save intrinsic. 402 //===----------------------------------------------------------------------===// 403 404 namespace { 405 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 406 public: 407 using OpConversionPattern::OpConversionPattern; 408 409 LogicalResult 410 matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands, 411 ConversionPatternRewriter &rewriter) const override { 412 // Save the coroutine state: @llvm.coro.save 413 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 414 op, AsyncAPI::tokenType(op->getContext()), operands); 415 416 return success(); 417 } 418 }; 419 } // namespace 420 421 //===----------------------------------------------------------------------===// 422 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 423 //===----------------------------------------------------------------------===// 424 425 namespace { 426 427 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 428 /// branch to the appropriate block based on the return code. 429 /// 430 /// Before: 431 /// 432 /// ^suspended: 433 /// "opBefore"(...) 434 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 435 /// ^resume: 436 /// "op"(...) 437 /// ^cleanup: ... 438 /// ^suspend: ... 439 /// 440 /// After: 441 /// 442 /// ^suspended: 443 /// "opBefore"(...) 444 /// %suspend = llmv.intr.coro.suspend ... 445 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 446 /// ^resume: 447 /// "op"(...) 448 /// ^cleanup: ... 449 /// ^suspend: ... 450 /// 451 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 452 public: 453 using OpConversionPattern::OpConversionPattern; 454 455 LogicalResult 456 matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands, 457 ConversionPatternRewriter &rewriter) const override { 458 auto i8 = rewriter.getIntegerType(8); 459 auto i32 = rewriter.getI32Type(); 460 auto loc = op->getLoc(); 461 462 // This is not a final suspension point. 463 auto constFalse = rewriter.create<LLVM::ConstantOp>( 464 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 465 466 // Suspend a coroutine: @llvm.coro.suspend 467 auto coroState = CoroSuspendOpAdaptor(operands).state(); 468 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 469 loc, i8, ValueRange({coroState, constFalse})); 470 471 // Cast return code to i32. 472 473 // After a suspension point decide if we should branch into resume, cleanup 474 // or suspend block of the coroutine (see @llvm.coro.suspend return code 475 // documentation). 476 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 477 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 478 op.cleanupDest()}; 479 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 480 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 481 /*defaultDestination=*/op.suspendDest(), 482 /*defaultOperands=*/ValueRange(), 483 /*caseValues=*/caseValues, 484 /*caseDestinations=*/caseDest, 485 /*caseOperands=*/ArrayRef<ValueRange>(), 486 /*branchWeights=*/ArrayRef<int32_t>()); 487 488 return success(); 489 } 490 }; 491 } // namespace 492 493 //===----------------------------------------------------------------------===// 494 // Convert async.runtime.create to the corresponding runtime API call. 495 // 496 // To allocate storage for the async values we use getelementptr trick: 497 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 498 //===----------------------------------------------------------------------===// 499 500 namespace { 501 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 502 public: 503 using OpConversionPattern::OpConversionPattern; 504 505 LogicalResult 506 matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands, 507 ConversionPatternRewriter &rewriter) const override { 508 TypeConverter *converter = getTypeConverter(); 509 Type resultType = op->getResultTypes()[0]; 510 511 // Tokens and Groups lowered to function calls without arguments. 512 if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) { 513 rewriter.replaceOpWithNewOp<CallOp>( 514 op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup, 515 converter->convertType(resultType)); 516 return success(); 517 } 518 519 // To create a value we need to compute the storage requirement. 520 if (auto value = resultType.dyn_cast<ValueType>()) { 521 // Returns the size requirements for the async value storage. 522 auto sizeOf = [&](ValueType valueType) -> Value { 523 auto loc = op->getLoc(); 524 auto i32 = rewriter.getI32Type(); 525 526 auto storedType = converter->convertType(valueType.getValueType()); 527 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 528 529 // %Size = getelementptr %T* null, int 1 530 // %SizeI = ptrtoint %T* %Size to i32 531 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 532 auto one = rewriter.create<LLVM::ConstantOp>( 533 loc, i32, rewriter.getI32IntegerAttr(1)); 534 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 535 one.getResult()); 536 return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep); 537 }; 538 539 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 540 sizeOf(value)); 541 542 return success(); 543 } 544 545 return rewriter.notifyMatchFailure(op, "unsupported async type"); 546 } 547 }; 548 } // namespace 549 550 //===----------------------------------------------------------------------===// 551 // Convert async.runtime.set_available to the corresponding runtime API call. 552 //===----------------------------------------------------------------------===// 553 554 namespace { 555 class RuntimeSetAvailableOpLowering 556 : public OpConversionPattern<RuntimeSetAvailableOp> { 557 public: 558 using OpConversionPattern::OpConversionPattern; 559 560 LogicalResult 561 matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands, 562 ConversionPatternRewriter &rewriter) const override { 563 Type operandType = op.operand().getType(); 564 565 if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) { 566 rewriter.create<CallOp>(op->getLoc(), 567 operandType.isa<TokenType>() ? kEmplaceToken 568 : kEmplaceValue, 569 TypeRange(), operands); 570 rewriter.eraseOp(op); 571 return success(); 572 } 573 574 return rewriter.notifyMatchFailure(op, "unsupported async type"); 575 } 576 }; 577 } // namespace 578 579 //===----------------------------------------------------------------------===// 580 // Convert async.runtime.await to the corresponding runtime API call. 581 //===----------------------------------------------------------------------===// 582 583 namespace { 584 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 585 public: 586 using OpConversionPattern::OpConversionPattern; 587 588 LogicalResult 589 matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands, 590 ConversionPatternRewriter &rewriter) const override { 591 Type operandType = op.operand().getType(); 592 593 StringRef apiFuncName; 594 if (operandType.isa<TokenType>()) 595 apiFuncName = kAwaitToken; 596 else if (operandType.isa<ValueType>()) 597 apiFuncName = kAwaitValue; 598 else if (operandType.isa<GroupType>()) 599 apiFuncName = kAwaitGroup; 600 else 601 return rewriter.notifyMatchFailure(op, "unsupported async type"); 602 603 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands); 604 rewriter.eraseOp(op); 605 606 return success(); 607 } 608 }; 609 } // namespace 610 611 //===----------------------------------------------------------------------===// 612 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 613 //===----------------------------------------------------------------------===// 614 615 namespace { 616 class RuntimeAwaitAndResumeOpLowering 617 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 618 public: 619 using OpConversionPattern::OpConversionPattern; 620 621 LogicalResult 622 matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands, 623 ConversionPatternRewriter &rewriter) const override { 624 Type operandType = op.operand().getType(); 625 626 StringRef apiFuncName; 627 if (operandType.isa<TokenType>()) 628 apiFuncName = kAwaitTokenAndExecute; 629 else if (operandType.isa<ValueType>()) 630 apiFuncName = kAwaitValueAndExecute; 631 else if (operandType.isa<GroupType>()) 632 apiFuncName = kAwaitAllAndExecute; 633 else 634 return rewriter.notifyMatchFailure(op, "unsupported async type"); 635 636 Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); 637 Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); 638 639 // A pointer to coroutine resume intrinsic wrapper. 640 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 641 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 642 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 643 644 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 645 ValueRange({operand, handle, resumePtr.res()})); 646 rewriter.eraseOp(op); 647 648 return success(); 649 } 650 }; 651 } // namespace 652 653 //===----------------------------------------------------------------------===// 654 // Convert async.runtime.resume to the corresponding runtime API call. 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 659 public: 660 using OpConversionPattern::OpConversionPattern; 661 662 LogicalResult 663 matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands, 664 ConversionPatternRewriter &rewriter) const override { 665 // A pointer to coroutine resume intrinsic wrapper. 666 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 667 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 668 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 669 670 // Call async runtime API to execute a coroutine in the managed thread. 671 auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); 672 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute, 673 ValueRange({coroHdl, resumePtr.res()})); 674 675 return success(); 676 } 677 }; 678 } // namespace 679 680 //===----------------------------------------------------------------------===// 681 // Convert async.runtime.store to the corresponding runtime API call. 682 //===----------------------------------------------------------------------===// 683 684 namespace { 685 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 686 public: 687 using OpConversionPattern::OpConversionPattern; 688 689 LogicalResult 690 matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands, 691 ConversionPatternRewriter &rewriter) const override { 692 Location loc = op->getLoc(); 693 694 // Get a pointer to the async value storage from the runtime. 695 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 696 auto storage = RuntimeStoreOpAdaptor(operands).storage(); 697 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 698 TypeRange(i8Ptr), storage); 699 700 // Cast from i8* to the LLVM pointer type. 701 auto valueType = op.value().getType(); 702 auto llvmValueType = getTypeConverter()->convertType(valueType); 703 if (!llvmValueType) 704 return rewriter.notifyMatchFailure( 705 op, "failed to convert stored value type to LLVM type"); 706 707 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 708 loc, LLVM::LLVMPointerType::get(llvmValueType), 709 storagePtr.getResult(0)); 710 711 // Store the yielded value into the async value storage. 712 auto value = RuntimeStoreOpAdaptor(operands).value(); 713 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 714 715 // Erase the original runtime store operation. 716 rewriter.eraseOp(op); 717 718 return success(); 719 } 720 }; 721 } // namespace 722 723 //===----------------------------------------------------------------------===// 724 // Convert async.runtime.load to the corresponding runtime API call. 725 //===----------------------------------------------------------------------===// 726 727 namespace { 728 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 729 public: 730 using OpConversionPattern::OpConversionPattern; 731 732 LogicalResult 733 matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands, 734 ConversionPatternRewriter &rewriter) const override { 735 Location loc = op->getLoc(); 736 737 // Get a pointer to the async value storage from the runtime. 738 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 739 auto storage = RuntimeLoadOpAdaptor(operands).storage(); 740 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 741 TypeRange(i8Ptr), storage); 742 743 // Cast from i8* to the LLVM pointer type. 744 auto valueType = op.result().getType(); 745 auto llvmValueType = getTypeConverter()->convertType(valueType); 746 if (!llvmValueType) 747 return rewriter.notifyMatchFailure( 748 op, "failed to convert loaded value type to LLVM type"); 749 750 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 751 loc, LLVM::LLVMPointerType::get(llvmValueType), 752 storagePtr.getResult(0)); 753 754 // Load from the casted pointer. 755 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 756 757 return success(); 758 } 759 }; 760 } // namespace 761 762 //===----------------------------------------------------------------------===// 763 // Convert async.runtime.add_to_group to the corresponding runtime API call. 764 //===----------------------------------------------------------------------===// 765 766 namespace { 767 class RuntimeAddToGroupOpLowering 768 : public OpConversionPattern<RuntimeAddToGroupOp> { 769 public: 770 using OpConversionPattern::OpConversionPattern; 771 772 LogicalResult 773 matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands, 774 ConversionPatternRewriter &rewriter) const override { 775 // Currently we can only add tokens to the group. 776 if (!op.operand().getType().isa<TokenType>()) 777 return rewriter.notifyMatchFailure(op, "only token type is supported"); 778 779 // Replace with a runtime API function call. 780 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, 781 rewriter.getI64Type(), operands); 782 783 return success(); 784 } 785 }; 786 } // namespace 787 788 //===----------------------------------------------------------------------===// 789 // Async reference counting ops lowering (`async.runtime.add_ref` and 790 // `async.runtime.drop_ref` to the corresponding API calls). 791 //===----------------------------------------------------------------------===// 792 793 namespace { 794 template <typename RefCountingOp> 795 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 796 public: 797 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 798 StringRef apiFunctionName) 799 : OpConversionPattern<RefCountingOp>(converter, ctx), 800 apiFunctionName(apiFunctionName) {} 801 802 LogicalResult 803 matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands, 804 ConversionPatternRewriter &rewriter) const override { 805 auto count = 806 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(), 807 rewriter.getI32IntegerAttr(op.count())); 808 809 auto operand = typename RefCountingOp::Adaptor(operands).operand(); 810 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 811 ValueRange({operand, count})); 812 813 return success(); 814 } 815 816 private: 817 StringRef apiFunctionName; 818 }; 819 820 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 821 public: 822 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 823 : RefCountingOpLowering(converter, ctx, kAddRef) {} 824 }; 825 826 class RuntimeDropRefOpLowering 827 : public RefCountingOpLowering<RuntimeDropRefOp> { 828 public: 829 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 830 : RefCountingOpLowering(converter, ctx, kDropRef) {} 831 }; 832 } // namespace 833 834 //===----------------------------------------------------------------------===// 835 // Convert return operations that return async values from async regions. 836 //===----------------------------------------------------------------------===// 837 838 namespace { 839 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 840 public: 841 using OpConversionPattern::OpConversionPattern; 842 843 LogicalResult 844 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 845 ConversionPatternRewriter &rewriter) const override { 846 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 847 return success(); 848 } 849 }; 850 } // namespace 851 852 //===----------------------------------------------------------------------===// 853 854 namespace { 855 struct ConvertAsyncToLLVMPass 856 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 857 void runOnOperation() override; 858 }; 859 } // namespace 860 861 void ConvertAsyncToLLVMPass::runOnOperation() { 862 ModuleOp module = getOperation(); 863 MLIRContext *ctx = module->getContext(); 864 865 // Add declarations for all functions required by the coroutines lowering. 866 addResumeFunction(module); 867 addAsyncRuntimeApiDeclarations(module); 868 addCRuntimeDeclarations(module); 869 870 // Lower async.runtime and async.coro operations to Async Runtime API and 871 // LLVM coroutine intrinsics. 872 873 // Convert async dialect types and operations to LLVM dialect. 874 AsyncRuntimeTypeConverter converter; 875 OwningRewritePatternList patterns; 876 877 // We use conversion to LLVM type to lower async.runtime load and store 878 // operations. 879 LLVMTypeConverter llvmConverter(ctx); 880 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 881 882 // Convert async types in function signatures and function calls. 883 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 884 populateCallOpTypeConversionPattern(patterns, ctx, converter); 885 886 // Convert return operations inside async.execute regions. 887 patterns.insert<ReturnOpOpConversion>(converter, ctx); 888 889 // Lower async.runtime operations to the async runtime API calls. 890 patterns.insert<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering, 891 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 892 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 893 RuntimeDropRefOpLowering>(converter, ctx); 894 895 // Lower async.runtime operations that rely on LLVM type converter to convert 896 // from async value payload type to the LLVM type. 897 patterns.insert<RuntimeCreateOpLowering, RuntimeStoreOpLowering, 898 RuntimeLoadOpLowering>(llvmConverter, ctx); 899 900 // Lower async coroutine operations to LLVM coroutine intrinsics. 901 patterns.insert<CoroIdOpConversion, CoroBeginOpConversion, 902 CoroFreeOpConversion, CoroEndOpConversion, 903 CoroSaveOpConversion, CoroSuspendOpConversion>(converter, 904 ctx); 905 906 ConversionTarget target(*ctx); 907 target.addLegalOp<ConstantOp>(); 908 target.addLegalDialect<LLVM::LLVMDialect>(); 909 910 // All operations from Async dialect must be lowered to the runtime API and 911 // LLVM intrinsics calls. 912 target.addIllegalDialect<AsyncDialect>(); 913 914 // Add dynamic legality constraints to apply conversions defined above. 915 target.addDynamicallyLegalOp<FuncOp>( 916 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 917 target.addDynamicallyLegalOp<ReturnOp>( 918 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 919 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 920 return converter.isSignatureLegal(op.getCalleeType()); 921 }); 922 923 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 924 signalPassFailure(); 925 } 926 927 //===----------------------------------------------------------------------===// 928 // Patterns for structural type conversions for the Async dialect operations. 929 //===----------------------------------------------------------------------===// 930 931 namespace { 932 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 933 public: 934 using OpConversionPattern::OpConversionPattern; 935 LogicalResult 936 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands, 937 ConversionPatternRewriter &rewriter) const override { 938 ExecuteOp newOp = 939 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 940 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 941 newOp.getRegion().end()); 942 943 // Set operands and update block argument and result types. 944 newOp->setOperands(operands); 945 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 946 return failure(); 947 for (auto result : newOp.getResults()) 948 result.setType(typeConverter->convertType(result.getType())); 949 950 rewriter.replaceOp(op, newOp.getResults()); 951 return success(); 952 } 953 }; 954 955 // Dummy pattern to trigger the appropriate type conversion / materialization. 956 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 957 public: 958 using OpConversionPattern::OpConversionPattern; 959 LogicalResult 960 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands, 961 ConversionPatternRewriter &rewriter) const override { 962 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front()); 963 return success(); 964 } 965 }; 966 967 // Dummy pattern to trigger the appropriate type conversion / materialization. 968 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 969 public: 970 using OpConversionPattern::OpConversionPattern; 971 LogicalResult 972 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 973 ConversionPatternRewriter &rewriter) const override { 974 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands); 975 return success(); 976 } 977 }; 978 } // namespace 979 980 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 981 return std::make_unique<ConvertAsyncToLLVMPass>(); 982 } 983 984 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 985 MLIRContext *context, TypeConverter &typeConverter, 986 OwningRewritePatternList &patterns, ConversionTarget &target) { 987 typeConverter.addConversion([&](TokenType type) { return type; }); 988 typeConverter.addConversion([&](ValueType type) { 989 return ValueType::get(typeConverter.convertType(type.getValueType())); 990 }); 991 992 patterns 993 .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 994 typeConverter, context); 995 996 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 997 [&](Operation *op) { return typeConverter.isLegal(op); }); 998 } 999