1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 #define DEBUG_TYPE "convert-async-to-llvm" 23 24 using namespace mlir; 25 using namespace mlir::async; 26 27 //===----------------------------------------------------------------------===// 28 // Async Runtime C API declaration. 29 //===----------------------------------------------------------------------===// 30 31 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 32 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 33 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 34 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 35 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 36 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 37 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 38 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 39 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 40 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 41 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 42 static constexpr const char *kGetValueStorage = 43 "mlirAsyncRuntimeGetValueStorage"; 44 static constexpr const char *kAddTokenToGroup = 45 "mlirAsyncRuntimeAddTokenToGroup"; 46 static constexpr const char *kAwaitTokenAndExecute = 47 "mlirAsyncRuntimeAwaitTokenAndExecute"; 48 static constexpr const char *kAwaitValueAndExecute = 49 "mlirAsyncRuntimeAwaitValueAndExecute"; 50 static constexpr const char *kAwaitAllAndExecute = 51 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 52 53 namespace { 54 /// Async Runtime API function types. 55 /// 56 /// Because we can't create API function signature for type parametrized 57 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 58 /// lowering all async data types become opaque pointers at runtime. 59 struct AsyncAPI { 60 // All async types are lowered to opaque i8* LLVM pointers at runtime. 61 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 62 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 63 } 64 65 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 66 return LLVM::LLVMTokenType::get(ctx); 67 } 68 69 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 70 auto ref = opaquePointerType(ctx); 71 auto count = IntegerType::get(ctx, 32); 72 return FunctionType::get(ctx, {ref, count}, {}); 73 } 74 75 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 76 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 77 } 78 79 static FunctionType createValueFunctionType(MLIRContext *ctx) { 80 auto i32 = IntegerType::get(ctx, 32); 81 auto value = opaquePointerType(ctx); 82 return FunctionType::get(ctx, {i32}, {value}); 83 } 84 85 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 86 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 87 } 88 89 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 90 auto value = opaquePointerType(ctx); 91 auto storage = opaquePointerType(ctx); 92 return FunctionType::get(ctx, {value}, {storage}); 93 } 94 95 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 96 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 97 } 98 99 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 100 auto value = opaquePointerType(ctx); 101 return FunctionType::get(ctx, {value}, {}); 102 } 103 104 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 105 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 106 } 107 108 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 109 auto value = opaquePointerType(ctx); 110 return FunctionType::get(ctx, {value}, {}); 111 } 112 113 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 114 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 115 } 116 117 static FunctionType executeFunctionType(MLIRContext *ctx) { 118 auto hdl = opaquePointerType(ctx); 119 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 120 return FunctionType::get(ctx, {hdl, resume}, {}); 121 } 122 123 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 124 auto i64 = IntegerType::get(ctx, 64); 125 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 126 {i64}); 127 } 128 129 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 130 auto hdl = opaquePointerType(ctx); 131 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 132 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 133 } 134 135 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 136 auto value = opaquePointerType(ctx); 137 auto hdl = opaquePointerType(ctx); 138 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 139 return FunctionType::get(ctx, {value, hdl, resume}, {}); 140 } 141 142 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 143 auto hdl = opaquePointerType(ctx); 144 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 145 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 146 } 147 148 // Auxiliary coroutine resume intrinsic wrapper. 149 static Type resumeFunctionType(MLIRContext *ctx) { 150 auto voidTy = LLVM::LLVMVoidType::get(ctx); 151 auto i8Ptr = opaquePointerType(ctx); 152 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 153 } 154 }; 155 } // namespace 156 157 /// Adds Async Runtime C API declarations to the module. 158 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 159 auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), 160 module.getBody()); 161 162 auto addFuncDecl = [&](StringRef name, FunctionType type) { 163 if (module.lookupSymbol(name)) 164 return; 165 builder.create<FuncOp>(name, type).setPrivate(); 166 }; 167 168 MLIRContext *ctx = module.getContext(); 169 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 170 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 171 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 172 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 173 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 174 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 175 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 176 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 177 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 178 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 179 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 180 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 181 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 182 addFuncDecl(kAwaitTokenAndExecute, 183 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 184 addFuncDecl(kAwaitValueAndExecute, 185 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 186 addFuncDecl(kAwaitAllAndExecute, 187 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 188 } 189 190 //===----------------------------------------------------------------------===// 191 // Add malloc/free declarations to the module. 192 //===----------------------------------------------------------------------===// 193 194 static constexpr const char *kMalloc = "malloc"; 195 static constexpr const char *kFree = "free"; 196 197 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 198 StringRef name, Type ret, ArrayRef<Type> params) { 199 if (module.lookupSymbol(name)) 200 return; 201 Type type = LLVM::LLVMFunctionType::get(ret, params); 202 builder.create<LLVM::LLVMFuncOp>(name, type); 203 } 204 205 /// Adds malloc/free declarations to the module. 206 static void addCRuntimeDeclarations(ModuleOp module) { 207 using namespace mlir::LLVM; 208 209 MLIRContext *ctx = module.getContext(); 210 ImplicitLocOpBuilder builder(module.getLoc(), 211 module.getBody()->getTerminator()); 212 213 auto voidTy = LLVMVoidType::get(ctx); 214 auto i64 = IntegerType::get(ctx, 64); 215 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 216 217 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 218 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 219 } 220 221 //===----------------------------------------------------------------------===// 222 // Coroutine resume function wrapper. 223 //===----------------------------------------------------------------------===// 224 225 static constexpr const char *kResume = "__resume"; 226 227 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 228 /// intrinsics. We need this function to be able to pass it to the async 229 /// runtime execute API. 230 static void addResumeFunction(ModuleOp module) { 231 if (module.lookupSymbol(kResume)) 232 return; 233 234 MLIRContext *ctx = module.getContext(); 235 236 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 237 Location loc = module.getLoc(); 238 239 auto voidTy = LLVM::LLVMVoidType::get(ctx); 240 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 241 242 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 243 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 244 resumeOp.setPrivate(); 245 246 auto *block = resumeOp.addEntryBlock(); 247 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 248 249 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 250 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 251 } 252 253 //===----------------------------------------------------------------------===// 254 // Convert Async dialect types to LLVM types. 255 //===----------------------------------------------------------------------===// 256 257 namespace { 258 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 259 /// their runtime type (opaque pointers) and does not convert any other types. 260 class AsyncRuntimeTypeConverter : public TypeConverter { 261 public: 262 AsyncRuntimeTypeConverter() { 263 addConversion([](Type type) { return type; }); 264 addConversion(convertAsyncTypes); 265 } 266 267 static Optional<Type> convertAsyncTypes(Type type) { 268 if (type.isa<TokenType, GroupType, ValueType>()) 269 return AsyncAPI::opaquePointerType(type.getContext()); 270 271 if (type.isa<CoroIdType, CoroStateType>()) 272 return AsyncAPI::tokenType(type.getContext()); 273 if (type.isa<CoroHandleType>()) 274 return AsyncAPI::opaquePointerType(type.getContext()); 275 276 return llvm::None; 277 } 278 }; 279 } // namespace 280 281 //===----------------------------------------------------------------------===// 282 // Convert async.coro.id to @llvm.coro.id intrinsic. 283 //===----------------------------------------------------------------------===// 284 285 namespace { 286 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 287 public: 288 using OpConversionPattern::OpConversionPattern; 289 290 LogicalResult 291 matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands, 292 ConversionPatternRewriter &rewriter) const override { 293 auto token = AsyncAPI::tokenType(op->getContext()); 294 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 295 auto loc = op->getLoc(); 296 297 // Constants for initializing coroutine frame. 298 auto constZero = rewriter.create<LLVM::ConstantOp>( 299 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 300 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 301 302 // Get coroutine id: @llvm.coro.id. 303 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 304 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 305 306 return success(); 307 } 308 }; 309 } // namespace 310 311 //===----------------------------------------------------------------------===// 312 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 313 //===----------------------------------------------------------------------===// 314 315 namespace { 316 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 317 public: 318 using OpConversionPattern::OpConversionPattern; 319 320 LogicalResult 321 matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands, 322 ConversionPatternRewriter &rewriter) const override { 323 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 324 auto loc = op->getLoc(); 325 326 // Get coroutine frame size: @llvm.coro.size.i64. 327 auto coroSize = 328 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 329 330 // Allocate memory for the coroutine frame. 331 auto coroAlloc = rewriter.create<LLVM::CallOp>( 332 loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc), 333 ValueRange(coroSize.getResult())); 334 335 // Begin a coroutine: @llvm.coro.begin. 336 auto coroId = CoroBeginOpAdaptor(operands).id(); 337 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 338 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 339 340 return success(); 341 } 342 }; 343 } // namespace 344 345 //===----------------------------------------------------------------------===// 346 // Convert async.coro.free to @llvm.coro.free intrinsic. 347 //===----------------------------------------------------------------------===// 348 349 namespace { 350 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 351 public: 352 using OpConversionPattern::OpConversionPattern; 353 354 LogicalResult 355 matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands, 356 ConversionPatternRewriter &rewriter) const override { 357 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 358 auto loc = op->getLoc(); 359 360 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 361 auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands); 362 363 // Free the memory. 364 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 365 rewriter.getSymbolRefAttr(kFree), 366 ValueRange(coroMem.getResult())); 367 368 return success(); 369 } 370 }; 371 } // namespace 372 373 //===----------------------------------------------------------------------===// 374 // Convert async.coro.end to @llvm.coro.end intrinsic. 375 //===----------------------------------------------------------------------===// 376 377 namespace { 378 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 379 public: 380 using OpConversionPattern::OpConversionPattern; 381 382 LogicalResult 383 matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands, 384 ConversionPatternRewriter &rewriter) const override { 385 // We are not in the block that is part of the unwind sequence. 386 auto constFalse = rewriter.create<LLVM::ConstantOp>( 387 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 388 389 // Mark the end of a coroutine: @llvm.coro.end. 390 auto coroHdl = CoroEndOpAdaptor(operands).handle(); 391 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 392 ValueRange({coroHdl, constFalse})); 393 rewriter.eraseOp(op); 394 395 return success(); 396 } 397 }; 398 } // namespace 399 400 //===----------------------------------------------------------------------===// 401 // Convert async.coro.save to @llvm.coro.save intrinsic. 402 //===----------------------------------------------------------------------===// 403 404 namespace { 405 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 406 public: 407 using OpConversionPattern::OpConversionPattern; 408 409 LogicalResult 410 matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands, 411 ConversionPatternRewriter &rewriter) const override { 412 // Save the coroutine state: @llvm.coro.save 413 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 414 op, AsyncAPI::tokenType(op->getContext()), operands); 415 416 return success(); 417 } 418 }; 419 } // namespace 420 421 //===----------------------------------------------------------------------===// 422 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 423 //===----------------------------------------------------------------------===// 424 425 namespace { 426 427 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 428 /// branch to the appropriate block based on the return code. 429 /// 430 /// Before: 431 /// 432 /// ^suspended: 433 /// "opBefore"(...) 434 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 435 /// ^resume: 436 /// "op"(...) 437 /// ^cleanup: ... 438 /// ^suspend: ... 439 /// 440 /// After: 441 /// 442 /// ^suspended: 443 /// "opBefore"(...) 444 /// %suspend = llmv.intr.coro.suspend ... 445 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 446 /// ^resume: 447 /// "op"(...) 448 /// ^cleanup: ... 449 /// ^suspend: ... 450 /// 451 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 452 public: 453 using OpConversionPattern::OpConversionPattern; 454 455 LogicalResult 456 matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands, 457 ConversionPatternRewriter &rewriter) const override { 458 auto i8 = rewriter.getIntegerType(8); 459 auto i32 = rewriter.getI32Type(); 460 auto loc = op->getLoc(); 461 462 // This is not a final suspension point. 463 auto constFalse = rewriter.create<LLVM::ConstantOp>( 464 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 465 466 // Suspend a coroutine: @llvm.coro.suspend 467 auto coroState = CoroSuspendOpAdaptor(operands).state(); 468 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 469 loc, i8, ValueRange({coroState, constFalse})); 470 471 // Cast return code to i32. 472 473 // After a suspension point decide if we should branch into resume, cleanup 474 // or suspend block of the coroutine (see @llvm.coro.suspend return code 475 // documentation). 476 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 477 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 478 op.cleanupDest()}; 479 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 480 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 481 /*defaultDestination=*/op.suspendDest(), 482 /*defaultOperands=*/ValueRange(), 483 /*caseValues=*/caseValues, 484 /*caseDestinations=*/caseDest, 485 /*caseOperands=*/ArrayRef<ValueRange>(), 486 /*branchWeights=*/ArrayRef<int32_t>()); 487 488 return success(); 489 } 490 }; 491 } // namespace 492 493 //===----------------------------------------------------------------------===// 494 // Convert async.runtime.create to the corresponding runtime API call. 495 // 496 // To allocate storage for the async values we use getelementptr trick: 497 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 498 //===----------------------------------------------------------------------===// 499 500 namespace { 501 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 502 public: 503 using OpConversionPattern::OpConversionPattern; 504 505 LogicalResult 506 matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands, 507 ConversionPatternRewriter &rewriter) const override { 508 TypeConverter *converter = getTypeConverter(); 509 Type resultType = op->getResultTypes()[0]; 510 511 // Tokens and Groups lowered to function calls without arguments. 512 if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) { 513 rewriter.replaceOpWithNewOp<CallOp>( 514 op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup, 515 converter->convertType(resultType)); 516 return success(); 517 } 518 519 // To create a value we need to compute the storage requirement. 520 if (auto value = resultType.dyn_cast<ValueType>()) { 521 // Returns the size requirements for the async value storage. 522 auto sizeOf = [&](ValueType valueType) -> Value { 523 auto loc = op->getLoc(); 524 auto i32 = rewriter.getI32Type(); 525 526 auto storedType = converter->convertType(valueType.getValueType()); 527 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 528 529 // %Size = getelementptr %T* null, int 1 530 // %SizeI = ptrtoint %T* %Size to i32 531 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 532 auto one = rewriter.create<LLVM::ConstantOp>( 533 loc, i32, rewriter.getI32IntegerAttr(1)); 534 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 535 one.getResult()); 536 return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep); 537 }; 538 539 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 540 sizeOf(value)); 541 542 return success(); 543 } 544 545 return rewriter.notifyMatchFailure(op, "unsupported async type"); 546 } 547 }; 548 } // namespace 549 550 //===----------------------------------------------------------------------===// 551 // Convert async.runtime.set_available to the corresponding runtime API call. 552 //===----------------------------------------------------------------------===// 553 554 namespace { 555 class RuntimeSetAvailableOpLowering 556 : public OpConversionPattern<RuntimeSetAvailableOp> { 557 public: 558 using OpConversionPattern::OpConversionPattern; 559 560 LogicalResult 561 matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands, 562 ConversionPatternRewriter &rewriter) const override { 563 Type operandType = op.operand().getType(); 564 565 if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) { 566 rewriter.create<CallOp>(op->getLoc(), 567 operandType.isa<TokenType>() ? kEmplaceToken 568 : kEmplaceValue, 569 TypeRange(), operands); 570 rewriter.eraseOp(op); 571 return success(); 572 } 573 574 return rewriter.notifyMatchFailure(op, "unsupported async type"); 575 } 576 }; 577 } // namespace 578 579 //===----------------------------------------------------------------------===// 580 // Convert async.runtime.await to the corresponding runtime API call. 581 //===----------------------------------------------------------------------===// 582 583 namespace { 584 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 585 public: 586 using OpConversionPattern::OpConversionPattern; 587 588 LogicalResult 589 matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands, 590 ConversionPatternRewriter &rewriter) const override { 591 Type operandType = op.operand().getType(); 592 593 StringRef apiFuncName; 594 if (operandType.isa<TokenType>()) 595 apiFuncName = kAwaitToken; 596 else if (operandType.isa<ValueType>()) 597 apiFuncName = kAwaitValue; 598 else if (operandType.isa<GroupType>()) 599 apiFuncName = kAwaitGroup; 600 else 601 return rewriter.notifyMatchFailure(op, "unsupported async type"); 602 603 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands); 604 rewriter.eraseOp(op); 605 606 return success(); 607 } 608 }; 609 } // namespace 610 611 //===----------------------------------------------------------------------===// 612 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 613 //===----------------------------------------------------------------------===// 614 615 namespace { 616 class RuntimeAwaitAndResumeOpLowering 617 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 618 public: 619 using OpConversionPattern::OpConversionPattern; 620 621 LogicalResult 622 matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands, 623 ConversionPatternRewriter &rewriter) const override { 624 Type operandType = op.operand().getType(); 625 626 StringRef apiFuncName; 627 if (operandType.isa<TokenType>()) 628 apiFuncName = kAwaitTokenAndExecute; 629 else if (operandType.isa<ValueType>()) 630 apiFuncName = kAwaitValueAndExecute; 631 else if (operandType.isa<GroupType>()) 632 apiFuncName = kAwaitAllAndExecute; 633 else 634 return rewriter.notifyMatchFailure(op, "unsupported async type"); 635 636 Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); 637 Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); 638 639 // A pointer to coroutine resume intrinsic wrapper. 640 addResumeFunction(op->getParentOfType<ModuleOp>()); 641 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 642 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 643 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 644 645 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 646 ValueRange({operand, handle, resumePtr.res()})); 647 rewriter.eraseOp(op); 648 649 return success(); 650 } 651 }; 652 } // namespace 653 654 //===----------------------------------------------------------------------===// 655 // Convert async.runtime.resume to the corresponding runtime API call. 656 //===----------------------------------------------------------------------===// 657 658 namespace { 659 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 660 public: 661 using OpConversionPattern::OpConversionPattern; 662 663 LogicalResult 664 matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands, 665 ConversionPatternRewriter &rewriter) const override { 666 // A pointer to coroutine resume intrinsic wrapper. 667 addResumeFunction(op->getParentOfType<ModuleOp>()); 668 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 669 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 670 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 671 672 // Call async runtime API to execute a coroutine in the managed thread. 673 auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); 674 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute, 675 ValueRange({coroHdl, resumePtr.res()})); 676 677 return success(); 678 } 679 }; 680 } // namespace 681 682 //===----------------------------------------------------------------------===// 683 // Convert async.runtime.store to the corresponding runtime API call. 684 //===----------------------------------------------------------------------===// 685 686 namespace { 687 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 688 public: 689 using OpConversionPattern::OpConversionPattern; 690 691 LogicalResult 692 matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands, 693 ConversionPatternRewriter &rewriter) const override { 694 Location loc = op->getLoc(); 695 696 // Get a pointer to the async value storage from the runtime. 697 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 698 auto storage = RuntimeStoreOpAdaptor(operands).storage(); 699 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 700 TypeRange(i8Ptr), storage); 701 702 // Cast from i8* to the LLVM pointer type. 703 auto valueType = op.value().getType(); 704 auto llvmValueType = getTypeConverter()->convertType(valueType); 705 if (!llvmValueType) 706 return rewriter.notifyMatchFailure( 707 op, "failed to convert stored value type to LLVM type"); 708 709 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 710 loc, LLVM::LLVMPointerType::get(llvmValueType), 711 storagePtr.getResult(0)); 712 713 // Store the yielded value into the async value storage. 714 auto value = RuntimeStoreOpAdaptor(operands).value(); 715 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 716 717 // Erase the original runtime store operation. 718 rewriter.eraseOp(op); 719 720 return success(); 721 } 722 }; 723 } // namespace 724 725 //===----------------------------------------------------------------------===// 726 // Convert async.runtime.load to the corresponding runtime API call. 727 //===----------------------------------------------------------------------===// 728 729 namespace { 730 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 731 public: 732 using OpConversionPattern::OpConversionPattern; 733 734 LogicalResult 735 matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands, 736 ConversionPatternRewriter &rewriter) const override { 737 Location loc = op->getLoc(); 738 739 // Get a pointer to the async value storage from the runtime. 740 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 741 auto storage = RuntimeLoadOpAdaptor(operands).storage(); 742 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 743 TypeRange(i8Ptr), storage); 744 745 // Cast from i8* to the LLVM pointer type. 746 auto valueType = op.result().getType(); 747 auto llvmValueType = getTypeConverter()->convertType(valueType); 748 if (!llvmValueType) 749 return rewriter.notifyMatchFailure( 750 op, "failed to convert loaded value type to LLVM type"); 751 752 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 753 loc, LLVM::LLVMPointerType::get(llvmValueType), 754 storagePtr.getResult(0)); 755 756 // Load from the casted pointer. 757 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 758 759 return success(); 760 } 761 }; 762 } // namespace 763 764 //===----------------------------------------------------------------------===// 765 // Convert async.runtime.add_to_group to the corresponding runtime API call. 766 //===----------------------------------------------------------------------===// 767 768 namespace { 769 class RuntimeAddToGroupOpLowering 770 : public OpConversionPattern<RuntimeAddToGroupOp> { 771 public: 772 using OpConversionPattern::OpConversionPattern; 773 774 LogicalResult 775 matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands, 776 ConversionPatternRewriter &rewriter) const override { 777 // Currently we can only add tokens to the group. 778 if (!op.operand().getType().isa<TokenType>()) 779 return rewriter.notifyMatchFailure(op, "only token type is supported"); 780 781 // Replace with a runtime API function call. 782 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, 783 rewriter.getI64Type(), operands); 784 785 return success(); 786 } 787 }; 788 } // namespace 789 790 //===----------------------------------------------------------------------===// 791 // Async reference counting ops lowering (`async.runtime.add_ref` and 792 // `async.runtime.drop_ref` to the corresponding API calls). 793 //===----------------------------------------------------------------------===// 794 795 namespace { 796 template <typename RefCountingOp> 797 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 798 public: 799 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 800 StringRef apiFunctionName) 801 : OpConversionPattern<RefCountingOp>(converter, ctx), 802 apiFunctionName(apiFunctionName) {} 803 804 LogicalResult 805 matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands, 806 ConversionPatternRewriter &rewriter) const override { 807 auto count = 808 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(), 809 rewriter.getI32IntegerAttr(op.count())); 810 811 auto operand = typename RefCountingOp::Adaptor(operands).operand(); 812 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 813 ValueRange({operand, count})); 814 815 return success(); 816 } 817 818 private: 819 StringRef apiFunctionName; 820 }; 821 822 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 823 public: 824 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 825 : RefCountingOpLowering(converter, ctx, kAddRef) {} 826 }; 827 828 class RuntimeDropRefOpLowering 829 : public RefCountingOpLowering<RuntimeDropRefOp> { 830 public: 831 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 832 : RefCountingOpLowering(converter, ctx, kDropRef) {} 833 }; 834 } // namespace 835 836 //===----------------------------------------------------------------------===// 837 // Convert return operations that return async values from async regions. 838 //===----------------------------------------------------------------------===// 839 840 namespace { 841 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 842 public: 843 using OpConversionPattern::OpConversionPattern; 844 845 LogicalResult 846 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 847 ConversionPatternRewriter &rewriter) const override { 848 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 849 return success(); 850 } 851 }; 852 } // namespace 853 854 //===----------------------------------------------------------------------===// 855 856 namespace { 857 struct ConvertAsyncToLLVMPass 858 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 859 void runOnOperation() override; 860 }; 861 } // namespace 862 863 void ConvertAsyncToLLVMPass::runOnOperation() { 864 ModuleOp module = getOperation(); 865 MLIRContext *ctx = module->getContext(); 866 867 // Add declarations for most functions required by the coroutines lowering. 868 // We delay adding the resume function until it's needed because it currently 869 // fails to compile unless '-O0' is specified. 870 addAsyncRuntimeApiDeclarations(module); 871 addCRuntimeDeclarations(module); 872 873 // Lower async.runtime and async.coro operations to Async Runtime API and 874 // LLVM coroutine intrinsics. 875 876 // Convert async dialect types and operations to LLVM dialect. 877 AsyncRuntimeTypeConverter converter; 878 OwningRewritePatternList patterns; 879 880 // We use conversion to LLVM type to lower async.runtime load and store 881 // operations. 882 LLVMTypeConverter llvmConverter(ctx); 883 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 884 885 // Convert async types in function signatures and function calls. 886 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 887 populateCallOpTypeConversionPattern(patterns, ctx, converter); 888 889 // Convert return operations inside async.execute regions. 890 patterns.insert<ReturnOpOpConversion>(converter, ctx); 891 892 // Lower async.runtime operations to the async runtime API calls. 893 patterns.insert<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering, 894 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 895 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 896 RuntimeDropRefOpLowering>(converter, ctx); 897 898 // Lower async.runtime operations that rely on LLVM type converter to convert 899 // from async value payload type to the LLVM type. 900 patterns.insert<RuntimeCreateOpLowering, RuntimeStoreOpLowering, 901 RuntimeLoadOpLowering>(llvmConverter, ctx); 902 903 // Lower async coroutine operations to LLVM coroutine intrinsics. 904 patterns.insert<CoroIdOpConversion, CoroBeginOpConversion, 905 CoroFreeOpConversion, CoroEndOpConversion, 906 CoroSaveOpConversion, CoroSuspendOpConversion>(converter, 907 ctx); 908 909 ConversionTarget target(*ctx); 910 target.addLegalOp<ConstantOp>(); 911 target.addLegalDialect<LLVM::LLVMDialect>(); 912 913 // All operations from Async dialect must be lowered to the runtime API and 914 // LLVM intrinsics calls. 915 target.addIllegalDialect<AsyncDialect>(); 916 917 // Add dynamic legality constraints to apply conversions defined above. 918 target.addDynamicallyLegalOp<FuncOp>( 919 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 920 target.addDynamicallyLegalOp<ReturnOp>( 921 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 922 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 923 return converter.isSignatureLegal(op.getCalleeType()); 924 }); 925 926 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 927 signalPassFailure(); 928 } 929 930 //===----------------------------------------------------------------------===// 931 // Patterns for structural type conversions for the Async dialect operations. 932 //===----------------------------------------------------------------------===// 933 934 namespace { 935 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 936 public: 937 using OpConversionPattern::OpConversionPattern; 938 LogicalResult 939 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands, 940 ConversionPatternRewriter &rewriter) const override { 941 ExecuteOp newOp = 942 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 943 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 944 newOp.getRegion().end()); 945 946 // Set operands and update block argument and result types. 947 newOp->setOperands(operands); 948 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 949 return failure(); 950 for (auto result : newOp.getResults()) 951 result.setType(typeConverter->convertType(result.getType())); 952 953 rewriter.replaceOp(op, newOp.getResults()); 954 return success(); 955 } 956 }; 957 958 // Dummy pattern to trigger the appropriate type conversion / materialization. 959 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 960 public: 961 using OpConversionPattern::OpConversionPattern; 962 LogicalResult 963 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands, 964 ConversionPatternRewriter &rewriter) const override { 965 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front()); 966 return success(); 967 } 968 }; 969 970 // Dummy pattern to trigger the appropriate type conversion / materialization. 971 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 972 public: 973 using OpConversionPattern::OpConversionPattern; 974 LogicalResult 975 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 976 ConversionPatternRewriter &rewriter) const override { 977 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands); 978 return success(); 979 } 980 }; 981 } // namespace 982 983 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 984 return std::make_unique<ConvertAsyncToLLVMPass>(); 985 } 986 987 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 988 MLIRContext *context, TypeConverter &typeConverter, 989 OwningRewritePatternList &patterns, ConversionTarget &target) { 990 typeConverter.addConversion([&](TokenType type) { return type; }); 991 typeConverter.addConversion([&](ValueType type) { 992 return ValueType::get(typeConverter.convertType(type.getValueType())); 993 }); 994 995 patterns 996 .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 997 typeConverter, context); 998 999 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1000 [&](Operation *op) { return typeConverter.isLegal(op); }); 1001 } 1002