1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 #define DEBUG_TYPE "convert-async-to-llvm" 23 24 using namespace mlir; 25 using namespace mlir::async; 26 27 //===----------------------------------------------------------------------===// 28 // Async Runtime C API declaration. 29 //===----------------------------------------------------------------------===// 30 31 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 32 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 33 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 34 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 35 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 36 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 37 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 38 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 39 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 40 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 41 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 42 static constexpr const char *kGetValueStorage = 43 "mlirAsyncRuntimeGetValueStorage"; 44 static constexpr const char *kAddTokenToGroup = 45 "mlirAsyncRuntimeAddTokenToGroup"; 46 static constexpr const char *kAwaitTokenAndExecute = 47 "mlirAsyncRuntimeAwaitTokenAndExecute"; 48 static constexpr const char *kAwaitValueAndExecute = 49 "mlirAsyncRuntimeAwaitValueAndExecute"; 50 static constexpr const char *kAwaitAllAndExecute = 51 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 52 53 namespace { 54 /// Async Runtime API function types. 55 /// 56 /// Because we can't create API function signature for type parametrized 57 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 58 /// lowering all async data types become opaque pointers at runtime. 59 struct AsyncAPI { 60 // All async types are lowered to opaque i8* LLVM pointers at runtime. 61 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 62 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 63 } 64 65 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 66 return LLVM::LLVMTokenType::get(ctx); 67 } 68 69 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 70 auto ref = opaquePointerType(ctx); 71 auto count = IntegerType::get(ctx, 32); 72 return FunctionType::get(ctx, {ref, count}, {}); 73 } 74 75 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 76 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 77 } 78 79 static FunctionType createValueFunctionType(MLIRContext *ctx) { 80 auto i32 = IntegerType::get(ctx, 32); 81 auto value = opaquePointerType(ctx); 82 return FunctionType::get(ctx, {i32}, {value}); 83 } 84 85 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 86 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 87 } 88 89 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 90 auto value = opaquePointerType(ctx); 91 auto storage = opaquePointerType(ctx); 92 return FunctionType::get(ctx, {value}, {storage}); 93 } 94 95 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 96 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 97 } 98 99 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 100 auto value = opaquePointerType(ctx); 101 return FunctionType::get(ctx, {value}, {}); 102 } 103 104 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 105 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 106 } 107 108 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 109 auto value = opaquePointerType(ctx); 110 return FunctionType::get(ctx, {value}, {}); 111 } 112 113 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 114 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 115 } 116 117 static FunctionType executeFunctionType(MLIRContext *ctx) { 118 auto hdl = opaquePointerType(ctx); 119 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 120 return FunctionType::get(ctx, {hdl, resume}, {}); 121 } 122 123 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 124 auto i64 = IntegerType::get(ctx, 64); 125 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 126 {i64}); 127 } 128 129 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 130 auto hdl = opaquePointerType(ctx); 131 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 132 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 133 } 134 135 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 136 auto value = opaquePointerType(ctx); 137 auto hdl = opaquePointerType(ctx); 138 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 139 return FunctionType::get(ctx, {value, hdl, resume}, {}); 140 } 141 142 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 143 auto hdl = opaquePointerType(ctx); 144 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 145 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 146 } 147 148 // Auxiliary coroutine resume intrinsic wrapper. 149 static Type resumeFunctionType(MLIRContext *ctx) { 150 auto voidTy = LLVM::LLVMVoidType::get(ctx); 151 auto i8Ptr = opaquePointerType(ctx); 152 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 153 } 154 }; 155 } // namespace 156 157 /// Adds Async Runtime C API declarations to the module. 158 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 159 auto builder = 160 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 161 162 auto addFuncDecl = [&](StringRef name, FunctionType type) { 163 if (module.lookupSymbol(name)) 164 return; 165 builder.create<FuncOp>(name, type).setPrivate(); 166 }; 167 168 MLIRContext *ctx = module.getContext(); 169 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 170 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 171 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 172 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 173 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 174 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 175 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 176 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 177 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 178 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 179 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 180 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 181 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 182 addFuncDecl(kAwaitTokenAndExecute, 183 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 184 addFuncDecl(kAwaitValueAndExecute, 185 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 186 addFuncDecl(kAwaitAllAndExecute, 187 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 188 } 189 190 //===----------------------------------------------------------------------===// 191 // Add malloc/free declarations to the module. 192 //===----------------------------------------------------------------------===// 193 194 static constexpr const char *kMalloc = "malloc"; 195 static constexpr const char *kFree = "free"; 196 197 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 198 StringRef name, Type ret, ArrayRef<Type> params) { 199 if (module.lookupSymbol(name)) 200 return; 201 Type type = LLVM::LLVMFunctionType::get(ret, params); 202 builder.create<LLVM::LLVMFuncOp>(name, type); 203 } 204 205 /// Adds malloc/free declarations to the module. 206 static void addCRuntimeDeclarations(ModuleOp module) { 207 using namespace mlir::LLVM; 208 209 MLIRContext *ctx = module.getContext(); 210 auto builder = 211 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 212 213 auto voidTy = LLVMVoidType::get(ctx); 214 auto i64 = IntegerType::get(ctx, 64); 215 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 216 217 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 218 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 219 } 220 221 //===----------------------------------------------------------------------===// 222 // Coroutine resume function wrapper. 223 //===----------------------------------------------------------------------===// 224 225 static constexpr const char *kResume = "__resume"; 226 227 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 228 /// intrinsics. We need this function to be able to pass it to the async 229 /// runtime execute API. 230 static void addResumeFunction(ModuleOp module) { 231 if (module.lookupSymbol(kResume)) 232 return; 233 234 MLIRContext *ctx = module.getContext(); 235 auto loc = module.getLoc(); 236 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 237 238 auto voidTy = LLVM::LLVMVoidType::get(ctx); 239 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 240 241 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 242 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 243 resumeOp.setPrivate(); 244 245 auto *block = resumeOp.addEntryBlock(); 246 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 247 248 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 249 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 250 } 251 252 //===----------------------------------------------------------------------===// 253 // Convert Async dialect types to LLVM types. 254 //===----------------------------------------------------------------------===// 255 256 namespace { 257 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 258 /// their runtime type (opaque pointers) and does not convert any other types. 259 class AsyncRuntimeTypeConverter : public TypeConverter { 260 public: 261 AsyncRuntimeTypeConverter() { 262 addConversion([](Type type) { return type; }); 263 addConversion(convertAsyncTypes); 264 } 265 266 static Optional<Type> convertAsyncTypes(Type type) { 267 if (type.isa<TokenType, GroupType, ValueType>()) 268 return AsyncAPI::opaquePointerType(type.getContext()); 269 270 if (type.isa<CoroIdType, CoroStateType>()) 271 return AsyncAPI::tokenType(type.getContext()); 272 if (type.isa<CoroHandleType>()) 273 return AsyncAPI::opaquePointerType(type.getContext()); 274 275 return llvm::None; 276 } 277 }; 278 } // namespace 279 280 //===----------------------------------------------------------------------===// 281 // Convert async.coro.id to @llvm.coro.id intrinsic. 282 //===----------------------------------------------------------------------===// 283 284 namespace { 285 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 286 public: 287 using OpConversionPattern::OpConversionPattern; 288 289 LogicalResult 290 matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands, 291 ConversionPatternRewriter &rewriter) const override { 292 auto token = AsyncAPI::tokenType(op->getContext()); 293 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 294 auto loc = op->getLoc(); 295 296 // Constants for initializing coroutine frame. 297 auto constZero = rewriter.create<LLVM::ConstantOp>( 298 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 299 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 300 301 // Get coroutine id: @llvm.coro.id. 302 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 303 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 304 305 return success(); 306 } 307 }; 308 } // namespace 309 310 //===----------------------------------------------------------------------===// 311 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 312 //===----------------------------------------------------------------------===// 313 314 namespace { 315 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 316 public: 317 using OpConversionPattern::OpConversionPattern; 318 319 LogicalResult 320 matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands, 321 ConversionPatternRewriter &rewriter) const override { 322 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 323 auto loc = op->getLoc(); 324 325 // Get coroutine frame size: @llvm.coro.size.i64. 326 auto coroSize = 327 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 328 329 // Allocate memory for the coroutine frame. 330 auto coroAlloc = rewriter.create<LLVM::CallOp>( 331 loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc), 332 ValueRange(coroSize.getResult())); 333 334 // Begin a coroutine: @llvm.coro.begin. 335 auto coroId = CoroBeginOpAdaptor(operands).id(); 336 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 337 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 338 339 return success(); 340 } 341 }; 342 } // namespace 343 344 //===----------------------------------------------------------------------===// 345 // Convert async.coro.free to @llvm.coro.free intrinsic. 346 //===----------------------------------------------------------------------===// 347 348 namespace { 349 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 350 public: 351 using OpConversionPattern::OpConversionPattern; 352 353 LogicalResult 354 matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands, 355 ConversionPatternRewriter &rewriter) const override { 356 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 357 auto loc = op->getLoc(); 358 359 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 360 auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands); 361 362 // Free the memory. 363 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 364 rewriter.getSymbolRefAttr(kFree), 365 ValueRange(coroMem.getResult())); 366 367 return success(); 368 } 369 }; 370 } // namespace 371 372 //===----------------------------------------------------------------------===// 373 // Convert async.coro.end to @llvm.coro.end intrinsic. 374 //===----------------------------------------------------------------------===// 375 376 namespace { 377 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 378 public: 379 using OpConversionPattern::OpConversionPattern; 380 381 LogicalResult 382 matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands, 383 ConversionPatternRewriter &rewriter) const override { 384 // We are not in the block that is part of the unwind sequence. 385 auto constFalse = rewriter.create<LLVM::ConstantOp>( 386 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 387 388 // Mark the end of a coroutine: @llvm.coro.end. 389 auto coroHdl = CoroEndOpAdaptor(operands).handle(); 390 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 391 ValueRange({coroHdl, constFalse})); 392 rewriter.eraseOp(op); 393 394 return success(); 395 } 396 }; 397 } // namespace 398 399 //===----------------------------------------------------------------------===// 400 // Convert async.coro.save to @llvm.coro.save intrinsic. 401 //===----------------------------------------------------------------------===// 402 403 namespace { 404 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 405 public: 406 using OpConversionPattern::OpConversionPattern; 407 408 LogicalResult 409 matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands, 410 ConversionPatternRewriter &rewriter) const override { 411 // Save the coroutine state: @llvm.coro.save 412 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 413 op, AsyncAPI::tokenType(op->getContext()), operands); 414 415 return success(); 416 } 417 }; 418 } // namespace 419 420 //===----------------------------------------------------------------------===// 421 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 422 //===----------------------------------------------------------------------===// 423 424 namespace { 425 426 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 427 /// branch to the appropriate block based on the return code. 428 /// 429 /// Before: 430 /// 431 /// ^suspended: 432 /// "opBefore"(...) 433 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 434 /// ^resume: 435 /// "op"(...) 436 /// ^cleanup: ... 437 /// ^suspend: ... 438 /// 439 /// After: 440 /// 441 /// ^suspended: 442 /// "opBefore"(...) 443 /// %suspend = llmv.intr.coro.suspend ... 444 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 445 /// ^resume: 446 /// "op"(...) 447 /// ^cleanup: ... 448 /// ^suspend: ... 449 /// 450 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 451 public: 452 using OpConversionPattern::OpConversionPattern; 453 454 LogicalResult 455 matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands, 456 ConversionPatternRewriter &rewriter) const override { 457 auto i8 = rewriter.getIntegerType(8); 458 auto i32 = rewriter.getI32Type(); 459 auto loc = op->getLoc(); 460 461 // This is not a final suspension point. 462 auto constFalse = rewriter.create<LLVM::ConstantOp>( 463 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 464 465 // Suspend a coroutine: @llvm.coro.suspend 466 auto coroState = CoroSuspendOpAdaptor(operands).state(); 467 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 468 loc, i8, ValueRange({coroState, constFalse})); 469 470 // Cast return code to i32. 471 472 // After a suspension point decide if we should branch into resume, cleanup 473 // or suspend block of the coroutine (see @llvm.coro.suspend return code 474 // documentation). 475 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 476 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 477 op.cleanupDest()}; 478 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 479 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 480 /*defaultDestination=*/op.suspendDest(), 481 /*defaultOperands=*/ValueRange(), 482 /*caseValues=*/caseValues, 483 /*caseDestinations=*/caseDest, 484 /*caseOperands=*/ArrayRef<ValueRange>(), 485 /*branchWeights=*/ArrayRef<int32_t>()); 486 487 return success(); 488 } 489 }; 490 } // namespace 491 492 //===----------------------------------------------------------------------===// 493 // Convert async.runtime.create to the corresponding runtime API call. 494 // 495 // To allocate storage for the async values we use getelementptr trick: 496 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 497 //===----------------------------------------------------------------------===// 498 499 namespace { 500 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 501 public: 502 using OpConversionPattern::OpConversionPattern; 503 504 LogicalResult 505 matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands, 506 ConversionPatternRewriter &rewriter) const override { 507 TypeConverter *converter = getTypeConverter(); 508 Type resultType = op->getResultTypes()[0]; 509 510 // Tokens and Groups lowered to function calls without arguments. 511 if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) { 512 rewriter.replaceOpWithNewOp<CallOp>( 513 op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup, 514 converter->convertType(resultType)); 515 return success(); 516 } 517 518 // To create a value we need to compute the storage requirement. 519 if (auto value = resultType.dyn_cast<ValueType>()) { 520 // Returns the size requirements for the async value storage. 521 auto sizeOf = [&](ValueType valueType) -> Value { 522 auto loc = op->getLoc(); 523 auto i32 = rewriter.getI32Type(); 524 525 auto storedType = converter->convertType(valueType.getValueType()); 526 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 527 528 // %Size = getelementptr %T* null, int 1 529 // %SizeI = ptrtoint %T* %Size to i32 530 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 531 auto one = rewriter.create<LLVM::ConstantOp>( 532 loc, i32, rewriter.getI32IntegerAttr(1)); 533 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 534 one.getResult()); 535 return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep); 536 }; 537 538 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 539 sizeOf(value)); 540 541 return success(); 542 } 543 544 return rewriter.notifyMatchFailure(op, "unsupported async type"); 545 } 546 }; 547 } // namespace 548 549 //===----------------------------------------------------------------------===// 550 // Convert async.runtime.set_available to the corresponding runtime API call. 551 //===----------------------------------------------------------------------===// 552 553 namespace { 554 class RuntimeSetAvailableOpLowering 555 : public OpConversionPattern<RuntimeSetAvailableOp> { 556 public: 557 using OpConversionPattern::OpConversionPattern; 558 559 LogicalResult 560 matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands, 561 ConversionPatternRewriter &rewriter) const override { 562 Type operandType = op.operand().getType(); 563 564 if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) { 565 rewriter.create<CallOp>(op->getLoc(), 566 operandType.isa<TokenType>() ? kEmplaceToken 567 : kEmplaceValue, 568 TypeRange(), operands); 569 rewriter.eraseOp(op); 570 return success(); 571 } 572 573 return rewriter.notifyMatchFailure(op, "unsupported async type"); 574 } 575 }; 576 } // namespace 577 578 //===----------------------------------------------------------------------===// 579 // Convert async.runtime.await to the corresponding runtime API call. 580 //===----------------------------------------------------------------------===// 581 582 namespace { 583 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 584 public: 585 using OpConversionPattern::OpConversionPattern; 586 587 LogicalResult 588 matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands, 589 ConversionPatternRewriter &rewriter) const override { 590 Type operandType = op.operand().getType(); 591 592 StringRef apiFuncName; 593 if (operandType.isa<TokenType>()) 594 apiFuncName = kAwaitToken; 595 else if (operandType.isa<ValueType>()) 596 apiFuncName = kAwaitValue; 597 else if (operandType.isa<GroupType>()) 598 apiFuncName = kAwaitGroup; 599 else 600 return rewriter.notifyMatchFailure(op, "unsupported async type"); 601 602 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands); 603 rewriter.eraseOp(op); 604 605 return success(); 606 } 607 }; 608 } // namespace 609 610 //===----------------------------------------------------------------------===// 611 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 612 //===----------------------------------------------------------------------===// 613 614 namespace { 615 class RuntimeAwaitAndResumeOpLowering 616 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 617 public: 618 using OpConversionPattern::OpConversionPattern; 619 620 LogicalResult 621 matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands, 622 ConversionPatternRewriter &rewriter) const override { 623 Type operandType = op.operand().getType(); 624 625 StringRef apiFuncName; 626 if (operandType.isa<TokenType>()) 627 apiFuncName = kAwaitTokenAndExecute; 628 else if (operandType.isa<ValueType>()) 629 apiFuncName = kAwaitValueAndExecute; 630 else if (operandType.isa<GroupType>()) 631 apiFuncName = kAwaitAllAndExecute; 632 else 633 return rewriter.notifyMatchFailure(op, "unsupported async type"); 634 635 Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); 636 Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); 637 638 // A pointer to coroutine resume intrinsic wrapper. 639 addResumeFunction(op->getParentOfType<ModuleOp>()); 640 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 641 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 642 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 643 644 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 645 ValueRange({operand, handle, resumePtr.res()})); 646 rewriter.eraseOp(op); 647 648 return success(); 649 } 650 }; 651 } // namespace 652 653 //===----------------------------------------------------------------------===// 654 // Convert async.runtime.resume to the corresponding runtime API call. 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 659 public: 660 using OpConversionPattern::OpConversionPattern; 661 662 LogicalResult 663 matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands, 664 ConversionPatternRewriter &rewriter) const override { 665 // A pointer to coroutine resume intrinsic wrapper. 666 addResumeFunction(op->getParentOfType<ModuleOp>()); 667 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 668 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 669 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 670 671 // Call async runtime API to execute a coroutine in the managed thread. 672 auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); 673 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute, 674 ValueRange({coroHdl, resumePtr.res()})); 675 676 return success(); 677 } 678 }; 679 } // namespace 680 681 //===----------------------------------------------------------------------===// 682 // Convert async.runtime.store to the corresponding runtime API call. 683 //===----------------------------------------------------------------------===// 684 685 namespace { 686 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 687 public: 688 using OpConversionPattern::OpConversionPattern; 689 690 LogicalResult 691 matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands, 692 ConversionPatternRewriter &rewriter) const override { 693 Location loc = op->getLoc(); 694 695 // Get a pointer to the async value storage from the runtime. 696 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 697 auto storage = RuntimeStoreOpAdaptor(operands).storage(); 698 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 699 TypeRange(i8Ptr), storage); 700 701 // Cast from i8* to the LLVM pointer type. 702 auto valueType = op.value().getType(); 703 auto llvmValueType = getTypeConverter()->convertType(valueType); 704 if (!llvmValueType) 705 return rewriter.notifyMatchFailure( 706 op, "failed to convert stored value type to LLVM type"); 707 708 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 709 loc, LLVM::LLVMPointerType::get(llvmValueType), 710 storagePtr.getResult(0)); 711 712 // Store the yielded value into the async value storage. 713 auto value = RuntimeStoreOpAdaptor(operands).value(); 714 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 715 716 // Erase the original runtime store operation. 717 rewriter.eraseOp(op); 718 719 return success(); 720 } 721 }; 722 } // namespace 723 724 //===----------------------------------------------------------------------===// 725 // Convert async.runtime.load to the corresponding runtime API call. 726 //===----------------------------------------------------------------------===// 727 728 namespace { 729 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 730 public: 731 using OpConversionPattern::OpConversionPattern; 732 733 LogicalResult 734 matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands, 735 ConversionPatternRewriter &rewriter) const override { 736 Location loc = op->getLoc(); 737 738 // Get a pointer to the async value storage from the runtime. 739 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 740 auto storage = RuntimeLoadOpAdaptor(operands).storage(); 741 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 742 TypeRange(i8Ptr), storage); 743 744 // Cast from i8* to the LLVM pointer type. 745 auto valueType = op.result().getType(); 746 auto llvmValueType = getTypeConverter()->convertType(valueType); 747 if (!llvmValueType) 748 return rewriter.notifyMatchFailure( 749 op, "failed to convert loaded value type to LLVM type"); 750 751 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 752 loc, LLVM::LLVMPointerType::get(llvmValueType), 753 storagePtr.getResult(0)); 754 755 // Load from the casted pointer. 756 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 757 758 return success(); 759 } 760 }; 761 } // namespace 762 763 //===----------------------------------------------------------------------===// 764 // Convert async.runtime.add_to_group to the corresponding runtime API call. 765 //===----------------------------------------------------------------------===// 766 767 namespace { 768 class RuntimeAddToGroupOpLowering 769 : public OpConversionPattern<RuntimeAddToGroupOp> { 770 public: 771 using OpConversionPattern::OpConversionPattern; 772 773 LogicalResult 774 matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands, 775 ConversionPatternRewriter &rewriter) const override { 776 // Currently we can only add tokens to the group. 777 if (!op.operand().getType().isa<TokenType>()) 778 return rewriter.notifyMatchFailure(op, "only token type is supported"); 779 780 // Replace with a runtime API function call. 781 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, 782 rewriter.getI64Type(), operands); 783 784 return success(); 785 } 786 }; 787 } // namespace 788 789 //===----------------------------------------------------------------------===// 790 // Async reference counting ops lowering (`async.runtime.add_ref` and 791 // `async.runtime.drop_ref` to the corresponding API calls). 792 //===----------------------------------------------------------------------===// 793 794 namespace { 795 template <typename RefCountingOp> 796 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 797 public: 798 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 799 StringRef apiFunctionName) 800 : OpConversionPattern<RefCountingOp>(converter, ctx), 801 apiFunctionName(apiFunctionName) {} 802 803 LogicalResult 804 matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands, 805 ConversionPatternRewriter &rewriter) const override { 806 auto count = 807 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(), 808 rewriter.getI32IntegerAttr(op.count())); 809 810 auto operand = typename RefCountingOp::Adaptor(operands).operand(); 811 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 812 ValueRange({operand, count})); 813 814 return success(); 815 } 816 817 private: 818 StringRef apiFunctionName; 819 }; 820 821 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 822 public: 823 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 824 : RefCountingOpLowering(converter, ctx, kAddRef) {} 825 }; 826 827 class RuntimeDropRefOpLowering 828 : public RefCountingOpLowering<RuntimeDropRefOp> { 829 public: 830 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 831 : RefCountingOpLowering(converter, ctx, kDropRef) {} 832 }; 833 } // namespace 834 835 //===----------------------------------------------------------------------===// 836 // Convert return operations that return async values from async regions. 837 //===----------------------------------------------------------------------===// 838 839 namespace { 840 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 841 public: 842 using OpConversionPattern::OpConversionPattern; 843 844 LogicalResult 845 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 846 ConversionPatternRewriter &rewriter) const override { 847 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 848 return success(); 849 } 850 }; 851 } // namespace 852 853 //===----------------------------------------------------------------------===// 854 855 namespace { 856 struct ConvertAsyncToLLVMPass 857 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 858 void runOnOperation() override; 859 }; 860 } // namespace 861 862 void ConvertAsyncToLLVMPass::runOnOperation() { 863 ModuleOp module = getOperation(); 864 MLIRContext *ctx = module->getContext(); 865 866 // Add declarations for most functions required by the coroutines lowering. 867 // We delay adding the resume function until it's needed because it currently 868 // fails to compile unless '-O0' is specified. 869 addAsyncRuntimeApiDeclarations(module); 870 addCRuntimeDeclarations(module); 871 872 // Lower async.runtime and async.coro operations to Async Runtime API and 873 // LLVM coroutine intrinsics. 874 875 // Convert async dialect types and operations to LLVM dialect. 876 AsyncRuntimeTypeConverter converter; 877 RewritePatternSet patterns(ctx); 878 879 // We use conversion to LLVM type to lower async.runtime load and store 880 // operations. 881 LLVMTypeConverter llvmConverter(ctx); 882 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 883 884 // Convert async types in function signatures and function calls. 885 populateFuncOpTypeConversionPattern(patterns, converter); 886 populateCallOpTypeConversionPattern(patterns, converter); 887 888 // Convert return operations inside async.execute regions. 889 patterns.add<ReturnOpOpConversion>(converter, ctx); 890 891 // Lower async.runtime operations to the async runtime API calls. 892 patterns.add<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering, 893 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 894 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 895 RuntimeDropRefOpLowering>(converter, ctx); 896 897 // Lower async.runtime operations that rely on LLVM type converter to convert 898 // from async value payload type to the LLVM type. 899 patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering, 900 RuntimeLoadOpLowering>(llvmConverter, ctx); 901 902 // Lower async coroutine operations to LLVM coroutine intrinsics. 903 patterns 904 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 905 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 906 converter, ctx); 907 908 ConversionTarget target(*ctx); 909 target.addLegalOp<ConstantOp>(); 910 target.addLegalDialect<LLVM::LLVMDialect>(); 911 912 // All operations from Async dialect must be lowered to the runtime API and 913 // LLVM intrinsics calls. 914 target.addIllegalDialect<AsyncDialect>(); 915 916 // Add dynamic legality constraints to apply conversions defined above. 917 target.addDynamicallyLegalOp<FuncOp>( 918 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 919 target.addDynamicallyLegalOp<ReturnOp>( 920 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 921 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 922 return converter.isSignatureLegal(op.getCalleeType()); 923 }); 924 925 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 926 signalPassFailure(); 927 } 928 929 //===----------------------------------------------------------------------===// 930 // Patterns for structural type conversions for the Async dialect operations. 931 //===----------------------------------------------------------------------===// 932 933 namespace { 934 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 935 public: 936 using OpConversionPattern::OpConversionPattern; 937 LogicalResult 938 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands, 939 ConversionPatternRewriter &rewriter) const override { 940 ExecuteOp newOp = 941 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 942 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 943 newOp.getRegion().end()); 944 945 // Set operands and update block argument and result types. 946 newOp->setOperands(operands); 947 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 948 return failure(); 949 for (auto result : newOp.getResults()) 950 result.setType(typeConverter->convertType(result.getType())); 951 952 rewriter.replaceOp(op, newOp.getResults()); 953 return success(); 954 } 955 }; 956 957 // Dummy pattern to trigger the appropriate type conversion / materialization. 958 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 959 public: 960 using OpConversionPattern::OpConversionPattern; 961 LogicalResult 962 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands, 963 ConversionPatternRewriter &rewriter) const override { 964 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front()); 965 return success(); 966 } 967 }; 968 969 // Dummy pattern to trigger the appropriate type conversion / materialization. 970 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 971 public: 972 using OpConversionPattern::OpConversionPattern; 973 LogicalResult 974 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 975 ConversionPatternRewriter &rewriter) const override { 976 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands); 977 return success(); 978 } 979 }; 980 } // namespace 981 982 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 983 return std::make_unique<ConvertAsyncToLLVMPass>(); 984 } 985 986 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 987 TypeConverter &typeConverter, RewritePatternSet &patterns, 988 ConversionTarget &target) { 989 typeConverter.addConversion([&](TokenType type) { return type; }); 990 typeConverter.addConversion([&](ValueType type) { 991 return ValueType::get(typeConverter.convertType(type.getValueType())); 992 }); 993 994 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 995 typeConverter, patterns.getContext()); 996 997 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 998 [&](Operation *op) { return typeConverter.isLegal(op); }); 999 } 1000