1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 #define DEBUG_TYPE "convert-async-to-llvm" 24 25 using namespace mlir; 26 using namespace mlir::async; 27 28 //===----------------------------------------------------------------------===// 29 // Async Runtime C API declaration. 30 //===----------------------------------------------------------------------===// 31 32 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 33 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 34 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 35 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 36 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 37 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 38 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 39 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 40 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 41 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 42 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 43 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 44 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 45 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 46 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 47 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 48 static constexpr const char *kGetValueStorage = 49 "mlirAsyncRuntimeGetValueStorage"; 50 static constexpr const char *kAddTokenToGroup = 51 "mlirAsyncRuntimeAddTokenToGroup"; 52 static constexpr const char *kAwaitTokenAndExecute = 53 "mlirAsyncRuntimeAwaitTokenAndExecute"; 54 static constexpr const char *kAwaitValueAndExecute = 55 "mlirAsyncRuntimeAwaitValueAndExecute"; 56 static constexpr const char *kAwaitAllAndExecute = 57 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 58 59 namespace { 60 /// Async Runtime API function types. 61 /// 62 /// Because we can't create API function signature for type parametrized 63 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 64 /// lowering all async data types become opaque pointers at runtime. 65 struct AsyncAPI { 66 // All async types are lowered to opaque i8* LLVM pointers at runtime. 67 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 68 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 69 } 70 71 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 72 return LLVM::LLVMTokenType::get(ctx); 73 } 74 75 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 76 auto ref = opaquePointerType(ctx); 77 auto count = IntegerType::get(ctx, 32); 78 return FunctionType::get(ctx, {ref, count}, {}); 79 } 80 81 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 82 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 83 } 84 85 static FunctionType createValueFunctionType(MLIRContext *ctx) { 86 auto i32 = IntegerType::get(ctx, 32); 87 auto value = opaquePointerType(ctx); 88 return FunctionType::get(ctx, {i32}, {value}); 89 } 90 91 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 92 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 93 } 94 95 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 96 auto value = opaquePointerType(ctx); 97 auto storage = opaquePointerType(ctx); 98 return FunctionType::get(ctx, {value}, {storage}); 99 } 100 101 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 102 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 103 } 104 105 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 106 auto value = opaquePointerType(ctx); 107 return FunctionType::get(ctx, {value}, {}); 108 } 109 110 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 111 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 112 } 113 114 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 115 auto value = opaquePointerType(ctx); 116 return FunctionType::get(ctx, {value}, {}); 117 } 118 119 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 120 auto i1 = IntegerType::get(ctx, 1); 121 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 122 } 123 124 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 125 auto value = opaquePointerType(ctx); 126 auto i1 = IntegerType::get(ctx, 1); 127 return FunctionType::get(ctx, {value}, {i1}); 128 } 129 130 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 131 auto i1 = IntegerType::get(ctx, 1); 132 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 133 } 134 135 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 136 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 137 } 138 139 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 140 auto value = opaquePointerType(ctx); 141 return FunctionType::get(ctx, {value}, {}); 142 } 143 144 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 145 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 146 } 147 148 static FunctionType executeFunctionType(MLIRContext *ctx) { 149 auto hdl = opaquePointerType(ctx); 150 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 151 return FunctionType::get(ctx, {hdl, resume}, {}); 152 } 153 154 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 155 auto i64 = IntegerType::get(ctx, 64); 156 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 157 {i64}); 158 } 159 160 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 161 auto hdl = opaquePointerType(ctx); 162 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 163 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 164 } 165 166 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 167 auto value = opaquePointerType(ctx); 168 auto hdl = opaquePointerType(ctx); 169 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 170 return FunctionType::get(ctx, {value, hdl, resume}, {}); 171 } 172 173 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 174 auto hdl = opaquePointerType(ctx); 175 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 176 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 177 } 178 179 // Auxiliary coroutine resume intrinsic wrapper. 180 static Type resumeFunctionType(MLIRContext *ctx) { 181 auto voidTy = LLVM::LLVMVoidType::get(ctx); 182 auto i8Ptr = opaquePointerType(ctx); 183 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 184 } 185 }; 186 } // namespace 187 188 /// Adds Async Runtime C API declarations to the module. 189 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 190 auto builder = 191 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 192 193 auto addFuncDecl = [&](StringRef name, FunctionType type) { 194 if (module.lookupSymbol(name)) 195 return; 196 builder.create<FuncOp>(name, type).setPrivate(); 197 }; 198 199 MLIRContext *ctx = module.getContext(); 200 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 201 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 202 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 203 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 204 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 205 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 206 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 207 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 208 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 209 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 210 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 211 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 212 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 213 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 214 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 215 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 216 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 217 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 218 addFuncDecl(kAwaitTokenAndExecute, 219 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 220 addFuncDecl(kAwaitValueAndExecute, 221 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 222 addFuncDecl(kAwaitAllAndExecute, 223 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 224 } 225 226 //===----------------------------------------------------------------------===// 227 // Add malloc/free declarations to the module. 228 //===----------------------------------------------------------------------===// 229 230 static constexpr const char *kMalloc = "malloc"; 231 static constexpr const char *kFree = "free"; 232 233 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 234 StringRef name, Type ret, ArrayRef<Type> params) { 235 if (module.lookupSymbol(name)) 236 return; 237 Type type = LLVM::LLVMFunctionType::get(ret, params); 238 builder.create<LLVM::LLVMFuncOp>(name, type); 239 } 240 241 /// Adds malloc/free declarations to the module. 242 static void addCRuntimeDeclarations(ModuleOp module) { 243 using namespace mlir::LLVM; 244 245 MLIRContext *ctx = module.getContext(); 246 auto builder = 247 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 248 249 auto voidTy = LLVMVoidType::get(ctx); 250 auto i64 = IntegerType::get(ctx, 64); 251 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 252 253 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 254 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 255 } 256 257 //===----------------------------------------------------------------------===// 258 // Coroutine resume function wrapper. 259 //===----------------------------------------------------------------------===// 260 261 static constexpr const char *kResume = "__resume"; 262 263 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 264 /// intrinsics. We need this function to be able to pass it to the async 265 /// runtime execute API. 266 static void addResumeFunction(ModuleOp module) { 267 if (module.lookupSymbol(kResume)) 268 return; 269 270 MLIRContext *ctx = module.getContext(); 271 auto loc = module.getLoc(); 272 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 273 274 auto voidTy = LLVM::LLVMVoidType::get(ctx); 275 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 276 277 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 278 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 279 resumeOp.setPrivate(); 280 281 auto *block = resumeOp.addEntryBlock(); 282 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 283 284 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 285 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 286 } 287 288 //===----------------------------------------------------------------------===// 289 // Convert Async dialect types to LLVM types. 290 //===----------------------------------------------------------------------===// 291 292 namespace { 293 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 294 /// their runtime type (opaque pointers) and does not convert any other types. 295 class AsyncRuntimeTypeConverter : public TypeConverter { 296 public: 297 AsyncRuntimeTypeConverter() { 298 addConversion([](Type type) { return type; }); 299 addConversion(convertAsyncTypes); 300 } 301 302 static Optional<Type> convertAsyncTypes(Type type) { 303 if (type.isa<TokenType, GroupType, ValueType>()) 304 return AsyncAPI::opaquePointerType(type.getContext()); 305 306 if (type.isa<CoroIdType, CoroStateType>()) 307 return AsyncAPI::tokenType(type.getContext()); 308 if (type.isa<CoroHandleType>()) 309 return AsyncAPI::opaquePointerType(type.getContext()); 310 311 return llvm::None; 312 } 313 }; 314 } // namespace 315 316 //===----------------------------------------------------------------------===// 317 // Convert async.coro.id to @llvm.coro.id intrinsic. 318 //===----------------------------------------------------------------------===// 319 320 namespace { 321 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 322 public: 323 using OpConversionPattern::OpConversionPattern; 324 325 LogicalResult 326 matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands, 327 ConversionPatternRewriter &rewriter) const override { 328 auto token = AsyncAPI::tokenType(op->getContext()); 329 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 330 auto loc = op->getLoc(); 331 332 // Constants for initializing coroutine frame. 333 auto constZero = rewriter.create<LLVM::ConstantOp>( 334 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 335 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 336 337 // Get coroutine id: @llvm.coro.id. 338 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 339 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 340 341 return success(); 342 } 343 }; 344 } // namespace 345 346 //===----------------------------------------------------------------------===// 347 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 348 //===----------------------------------------------------------------------===// 349 350 namespace { 351 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 352 public: 353 using OpConversionPattern::OpConversionPattern; 354 355 LogicalResult 356 matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands, 357 ConversionPatternRewriter &rewriter) const override { 358 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 359 auto loc = op->getLoc(); 360 361 // Get coroutine frame size: @llvm.coro.size.i64. 362 auto coroSize = 363 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 364 365 // Allocate memory for the coroutine frame. 366 auto coroAlloc = rewriter.create<LLVM::CallOp>( 367 loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc), 368 ValueRange(coroSize.getResult())); 369 370 // Begin a coroutine: @llvm.coro.begin. 371 auto coroId = CoroBeginOpAdaptor(operands).id(); 372 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 373 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 374 375 return success(); 376 } 377 }; 378 } // namespace 379 380 //===----------------------------------------------------------------------===// 381 // Convert async.coro.free to @llvm.coro.free intrinsic. 382 //===----------------------------------------------------------------------===// 383 384 namespace { 385 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 386 public: 387 using OpConversionPattern::OpConversionPattern; 388 389 LogicalResult 390 matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands, 391 ConversionPatternRewriter &rewriter) const override { 392 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 393 auto loc = op->getLoc(); 394 395 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 396 auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands); 397 398 // Free the memory. 399 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 400 rewriter.getSymbolRefAttr(kFree), 401 ValueRange(coroMem.getResult())); 402 403 return success(); 404 } 405 }; 406 } // namespace 407 408 //===----------------------------------------------------------------------===// 409 // Convert async.coro.end to @llvm.coro.end intrinsic. 410 //===----------------------------------------------------------------------===// 411 412 namespace { 413 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 414 public: 415 using OpConversionPattern::OpConversionPattern; 416 417 LogicalResult 418 matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands, 419 ConversionPatternRewriter &rewriter) const override { 420 // We are not in the block that is part of the unwind sequence. 421 auto constFalse = rewriter.create<LLVM::ConstantOp>( 422 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 423 424 // Mark the end of a coroutine: @llvm.coro.end. 425 auto coroHdl = CoroEndOpAdaptor(operands).handle(); 426 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 427 ValueRange({coroHdl, constFalse})); 428 rewriter.eraseOp(op); 429 430 return success(); 431 } 432 }; 433 } // namespace 434 435 //===----------------------------------------------------------------------===// 436 // Convert async.coro.save to @llvm.coro.save intrinsic. 437 //===----------------------------------------------------------------------===// 438 439 namespace { 440 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 441 public: 442 using OpConversionPattern::OpConversionPattern; 443 444 LogicalResult 445 matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands, 446 ConversionPatternRewriter &rewriter) const override { 447 // Save the coroutine state: @llvm.coro.save 448 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 449 op, AsyncAPI::tokenType(op->getContext()), operands); 450 451 return success(); 452 } 453 }; 454 } // namespace 455 456 //===----------------------------------------------------------------------===// 457 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 458 //===----------------------------------------------------------------------===// 459 460 namespace { 461 462 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 463 /// branch to the appropriate block based on the return code. 464 /// 465 /// Before: 466 /// 467 /// ^suspended: 468 /// "opBefore"(...) 469 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 470 /// ^resume: 471 /// "op"(...) 472 /// ^cleanup: ... 473 /// ^suspend: ... 474 /// 475 /// After: 476 /// 477 /// ^suspended: 478 /// "opBefore"(...) 479 /// %suspend = llmv.intr.coro.suspend ... 480 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 481 /// ^resume: 482 /// "op"(...) 483 /// ^cleanup: ... 484 /// ^suspend: ... 485 /// 486 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 487 public: 488 using OpConversionPattern::OpConversionPattern; 489 490 LogicalResult 491 matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands, 492 ConversionPatternRewriter &rewriter) const override { 493 auto i8 = rewriter.getIntegerType(8); 494 auto i32 = rewriter.getI32Type(); 495 auto loc = op->getLoc(); 496 497 // This is not a final suspension point. 498 auto constFalse = rewriter.create<LLVM::ConstantOp>( 499 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 500 501 // Suspend a coroutine: @llvm.coro.suspend 502 auto coroState = CoroSuspendOpAdaptor(operands).state(); 503 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 504 loc, i8, ValueRange({coroState, constFalse})); 505 506 // Cast return code to i32. 507 508 // After a suspension point decide if we should branch into resume, cleanup 509 // or suspend block of the coroutine (see @llvm.coro.suspend return code 510 // documentation). 511 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 512 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 513 op.cleanupDest()}; 514 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 515 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 516 /*defaultDestination=*/op.suspendDest(), 517 /*defaultOperands=*/ValueRange(), 518 /*caseValues=*/caseValues, 519 /*caseDestinations=*/caseDest, 520 /*caseOperands=*/ArrayRef<ValueRange>(), 521 /*branchWeights=*/ArrayRef<int32_t>()); 522 523 return success(); 524 } 525 }; 526 } // namespace 527 528 //===----------------------------------------------------------------------===// 529 // Convert async.runtime.create to the corresponding runtime API call. 530 // 531 // To allocate storage for the async values we use getelementptr trick: 532 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 533 //===----------------------------------------------------------------------===// 534 535 namespace { 536 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 537 public: 538 using OpConversionPattern::OpConversionPattern; 539 540 LogicalResult 541 matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands, 542 ConversionPatternRewriter &rewriter) const override { 543 TypeConverter *converter = getTypeConverter(); 544 Type resultType = op->getResultTypes()[0]; 545 546 // Tokens and Groups lowered to function calls without arguments. 547 if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) { 548 rewriter.replaceOpWithNewOp<CallOp>( 549 op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup, 550 converter->convertType(resultType)); 551 return success(); 552 } 553 554 // To create a value we need to compute the storage requirement. 555 if (auto value = resultType.dyn_cast<ValueType>()) { 556 // Returns the size requirements for the async value storage. 557 auto sizeOf = [&](ValueType valueType) -> Value { 558 auto loc = op->getLoc(); 559 auto i32 = rewriter.getI32Type(); 560 561 auto storedType = converter->convertType(valueType.getValueType()); 562 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 563 564 // %Size = getelementptr %T* null, int 1 565 // %SizeI = ptrtoint %T* %Size to i32 566 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 567 auto one = rewriter.create<LLVM::ConstantOp>( 568 loc, i32, rewriter.getI32IntegerAttr(1)); 569 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 570 one.getResult()); 571 return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep); 572 }; 573 574 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 575 sizeOf(value)); 576 577 return success(); 578 } 579 580 return rewriter.notifyMatchFailure(op, "unsupported async type"); 581 } 582 }; 583 } // namespace 584 585 //===----------------------------------------------------------------------===// 586 // Convert async.runtime.set_available to the corresponding runtime API call. 587 //===----------------------------------------------------------------------===// 588 589 namespace { 590 class RuntimeSetAvailableOpLowering 591 : public OpConversionPattern<RuntimeSetAvailableOp> { 592 public: 593 using OpConversionPattern::OpConversionPattern; 594 595 LogicalResult 596 matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands, 597 ConversionPatternRewriter &rewriter) const override { 598 StringRef apiFuncName = 599 TypeSwitch<Type, StringRef>(op.operand().getType()) 600 .Case<TokenType>([](Type) { return kEmplaceToken; }) 601 .Case<ValueType>([](Type) { return kEmplaceValue; }); 602 603 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands); 604 605 return success(); 606 } 607 }; 608 } // namespace 609 610 //===----------------------------------------------------------------------===// 611 // Convert async.runtime.set_error to the corresponding runtime API call. 612 //===----------------------------------------------------------------------===// 613 614 namespace { 615 class RuntimeSetErrorOpLowering 616 : public OpConversionPattern<RuntimeSetErrorOp> { 617 public: 618 using OpConversionPattern::OpConversionPattern; 619 620 LogicalResult 621 matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands, 622 ConversionPatternRewriter &rewriter) const override { 623 StringRef apiFuncName = 624 TypeSwitch<Type, StringRef>(op.operand().getType()) 625 .Case<TokenType>([](Type) { return kSetTokenError; }) 626 .Case<ValueType>([](Type) { return kSetValueError; }); 627 628 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands); 629 630 return success(); 631 } 632 }; 633 } // namespace 634 635 //===----------------------------------------------------------------------===// 636 // Convert async.runtime.is_error to the corresponding runtime API call. 637 //===----------------------------------------------------------------------===// 638 639 namespace { 640 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 641 public: 642 using OpConversionPattern::OpConversionPattern; 643 644 LogicalResult 645 matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands, 646 ConversionPatternRewriter &rewriter) const override { 647 StringRef apiFuncName = 648 TypeSwitch<Type, StringRef>(op.operand().getType()) 649 .Case<TokenType>([](Type) { return kIsTokenError; }) 650 .Case<GroupType>([](Type) { return kIsGroupError; }) 651 .Case<ValueType>([](Type) { return kIsValueError; }); 652 653 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 654 operands); 655 return success(); 656 } 657 }; 658 } // namespace 659 660 //===----------------------------------------------------------------------===// 661 // Convert async.runtime.await to the corresponding runtime API call. 662 //===----------------------------------------------------------------------===// 663 664 namespace { 665 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 666 public: 667 using OpConversionPattern::OpConversionPattern; 668 669 LogicalResult 670 matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands, 671 ConversionPatternRewriter &rewriter) const override { 672 StringRef apiFuncName = 673 TypeSwitch<Type, StringRef>(op.operand().getType()) 674 .Case<TokenType>([](Type) { return kAwaitToken; }) 675 .Case<ValueType>([](Type) { return kAwaitValue; }) 676 .Case<GroupType>([](Type) { return kAwaitGroup; }); 677 678 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands); 679 rewriter.eraseOp(op); 680 681 return success(); 682 } 683 }; 684 } // namespace 685 686 //===----------------------------------------------------------------------===// 687 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 688 //===----------------------------------------------------------------------===// 689 690 namespace { 691 class RuntimeAwaitAndResumeOpLowering 692 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 693 public: 694 using OpConversionPattern::OpConversionPattern; 695 696 LogicalResult 697 matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands, 698 ConversionPatternRewriter &rewriter) const override { 699 StringRef apiFuncName = 700 TypeSwitch<Type, StringRef>(op.operand().getType()) 701 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 702 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 703 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 704 705 Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); 706 Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); 707 708 // A pointer to coroutine resume intrinsic wrapper. 709 addResumeFunction(op->getParentOfType<ModuleOp>()); 710 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 711 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 712 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 713 714 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 715 ValueRange({operand, handle, resumePtr.res()})); 716 rewriter.eraseOp(op); 717 718 return success(); 719 } 720 }; 721 } // namespace 722 723 //===----------------------------------------------------------------------===// 724 // Convert async.runtime.resume to the corresponding runtime API call. 725 //===----------------------------------------------------------------------===// 726 727 namespace { 728 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 729 public: 730 using OpConversionPattern::OpConversionPattern; 731 732 LogicalResult 733 matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands, 734 ConversionPatternRewriter &rewriter) const override { 735 // A pointer to coroutine resume intrinsic wrapper. 736 addResumeFunction(op->getParentOfType<ModuleOp>()); 737 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 738 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 739 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 740 741 // Call async runtime API to execute a coroutine in the managed thread. 742 auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); 743 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute, 744 ValueRange({coroHdl, resumePtr.res()})); 745 746 return success(); 747 } 748 }; 749 } // namespace 750 751 //===----------------------------------------------------------------------===// 752 // Convert async.runtime.store to the corresponding runtime API call. 753 //===----------------------------------------------------------------------===// 754 755 namespace { 756 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 757 public: 758 using OpConversionPattern::OpConversionPattern; 759 760 LogicalResult 761 matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands, 762 ConversionPatternRewriter &rewriter) const override { 763 Location loc = op->getLoc(); 764 765 // Get a pointer to the async value storage from the runtime. 766 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 767 auto storage = RuntimeStoreOpAdaptor(operands).storage(); 768 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 769 TypeRange(i8Ptr), storage); 770 771 // Cast from i8* to the LLVM pointer type. 772 auto valueType = op.value().getType(); 773 auto llvmValueType = getTypeConverter()->convertType(valueType); 774 if (!llvmValueType) 775 return rewriter.notifyMatchFailure( 776 op, "failed to convert stored value type to LLVM type"); 777 778 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 779 loc, LLVM::LLVMPointerType::get(llvmValueType), 780 storagePtr.getResult(0)); 781 782 // Store the yielded value into the async value storage. 783 auto value = RuntimeStoreOpAdaptor(operands).value(); 784 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 785 786 // Erase the original runtime store operation. 787 rewriter.eraseOp(op); 788 789 return success(); 790 } 791 }; 792 } // namespace 793 794 //===----------------------------------------------------------------------===// 795 // Convert async.runtime.load to the corresponding runtime API call. 796 //===----------------------------------------------------------------------===// 797 798 namespace { 799 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 800 public: 801 using OpConversionPattern::OpConversionPattern; 802 803 LogicalResult 804 matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands, 805 ConversionPatternRewriter &rewriter) const override { 806 Location loc = op->getLoc(); 807 808 // Get a pointer to the async value storage from the runtime. 809 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 810 auto storage = RuntimeLoadOpAdaptor(operands).storage(); 811 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 812 TypeRange(i8Ptr), storage); 813 814 // Cast from i8* to the LLVM pointer type. 815 auto valueType = op.result().getType(); 816 auto llvmValueType = getTypeConverter()->convertType(valueType); 817 if (!llvmValueType) 818 return rewriter.notifyMatchFailure( 819 op, "failed to convert loaded value type to LLVM type"); 820 821 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 822 loc, LLVM::LLVMPointerType::get(llvmValueType), 823 storagePtr.getResult(0)); 824 825 // Load from the casted pointer. 826 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 827 828 return success(); 829 } 830 }; 831 } // namespace 832 833 //===----------------------------------------------------------------------===// 834 // Convert async.runtime.add_to_group to the corresponding runtime API call. 835 //===----------------------------------------------------------------------===// 836 837 namespace { 838 class RuntimeAddToGroupOpLowering 839 : public OpConversionPattern<RuntimeAddToGroupOp> { 840 public: 841 using OpConversionPattern::OpConversionPattern; 842 843 LogicalResult 844 matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands, 845 ConversionPatternRewriter &rewriter) const override { 846 // Currently we can only add tokens to the group. 847 if (!op.operand().getType().isa<TokenType>()) 848 return rewriter.notifyMatchFailure(op, "only token type is supported"); 849 850 // Replace with a runtime API function call. 851 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, 852 rewriter.getI64Type(), operands); 853 854 return success(); 855 } 856 }; 857 } // namespace 858 859 //===----------------------------------------------------------------------===// 860 // Async reference counting ops lowering (`async.runtime.add_ref` and 861 // `async.runtime.drop_ref` to the corresponding API calls). 862 //===----------------------------------------------------------------------===// 863 864 namespace { 865 template <typename RefCountingOp> 866 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 867 public: 868 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 869 StringRef apiFunctionName) 870 : OpConversionPattern<RefCountingOp>(converter, ctx), 871 apiFunctionName(apiFunctionName) {} 872 873 LogicalResult 874 matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands, 875 ConversionPatternRewriter &rewriter) const override { 876 auto count = 877 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(), 878 rewriter.getI32IntegerAttr(op.count())); 879 880 auto operand = typename RefCountingOp::Adaptor(operands).operand(); 881 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 882 ValueRange({operand, count})); 883 884 return success(); 885 } 886 887 private: 888 StringRef apiFunctionName; 889 }; 890 891 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 892 public: 893 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 894 : RefCountingOpLowering(converter, ctx, kAddRef) {} 895 }; 896 897 class RuntimeDropRefOpLowering 898 : public RefCountingOpLowering<RuntimeDropRefOp> { 899 public: 900 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 901 : RefCountingOpLowering(converter, ctx, kDropRef) {} 902 }; 903 } // namespace 904 905 //===----------------------------------------------------------------------===// 906 // Convert return operations that return async values from async regions. 907 //===----------------------------------------------------------------------===// 908 909 namespace { 910 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 911 public: 912 using OpConversionPattern::OpConversionPattern; 913 914 LogicalResult 915 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 916 ConversionPatternRewriter &rewriter) const override { 917 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 918 return success(); 919 } 920 }; 921 } // namespace 922 923 //===----------------------------------------------------------------------===// 924 925 namespace { 926 struct ConvertAsyncToLLVMPass 927 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 928 void runOnOperation() override; 929 }; 930 } // namespace 931 932 void ConvertAsyncToLLVMPass::runOnOperation() { 933 ModuleOp module = getOperation(); 934 MLIRContext *ctx = module->getContext(); 935 936 // Add declarations for most functions required by the coroutines lowering. 937 // We delay adding the resume function until it's needed because it currently 938 // fails to compile unless '-O0' is specified. 939 addAsyncRuntimeApiDeclarations(module); 940 addCRuntimeDeclarations(module); 941 942 // Lower async.runtime and async.coro operations to Async Runtime API and 943 // LLVM coroutine intrinsics. 944 945 // Convert async dialect types and operations to LLVM dialect. 946 AsyncRuntimeTypeConverter converter; 947 RewritePatternSet patterns(ctx); 948 949 // We use conversion to LLVM type to lower async.runtime load and store 950 // operations. 951 LLVMTypeConverter llvmConverter(ctx); 952 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 953 954 // Convert async types in function signatures and function calls. 955 populateFuncOpTypeConversionPattern(patterns, converter); 956 populateCallOpTypeConversionPattern(patterns, converter); 957 958 // Convert return operations inside async.execute regions. 959 patterns.add<ReturnOpOpConversion>(converter, ctx); 960 961 // Lower async.runtime operations to the async runtime API calls. 962 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 963 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 964 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 965 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 966 RuntimeDropRefOpLowering>(converter, ctx); 967 968 // Lower async.runtime operations that rely on LLVM type converter to convert 969 // from async value payload type to the LLVM type. 970 patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering, 971 RuntimeLoadOpLowering>(llvmConverter, ctx); 972 973 // Lower async coroutine operations to LLVM coroutine intrinsics. 974 patterns 975 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 976 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 977 converter, ctx); 978 979 ConversionTarget target(*ctx); 980 target.addLegalOp<ConstantOp>(); 981 target.addLegalDialect<LLVM::LLVMDialect>(); 982 983 // All operations from Async dialect must be lowered to the runtime API and 984 // LLVM intrinsics calls. 985 target.addIllegalDialect<AsyncDialect>(); 986 987 // Add dynamic legality constraints to apply conversions defined above. 988 target.addDynamicallyLegalOp<FuncOp>( 989 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 990 target.addDynamicallyLegalOp<ReturnOp>( 991 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 992 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 993 return converter.isSignatureLegal(op.getCalleeType()); 994 }); 995 996 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 997 signalPassFailure(); 998 } 999 1000 //===----------------------------------------------------------------------===// 1001 // Patterns for structural type conversions for the Async dialect operations. 1002 //===----------------------------------------------------------------------===// 1003 1004 namespace { 1005 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1006 public: 1007 using OpConversionPattern::OpConversionPattern; 1008 LogicalResult 1009 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands, 1010 ConversionPatternRewriter &rewriter) const override { 1011 ExecuteOp newOp = 1012 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1013 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1014 newOp.getRegion().end()); 1015 1016 // Set operands and update block argument and result types. 1017 newOp->setOperands(operands); 1018 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1019 return failure(); 1020 for (auto result : newOp.getResults()) 1021 result.setType(typeConverter->convertType(result.getType())); 1022 1023 rewriter.replaceOp(op, newOp.getResults()); 1024 return success(); 1025 } 1026 }; 1027 1028 // Dummy pattern to trigger the appropriate type conversion / materialization. 1029 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1030 public: 1031 using OpConversionPattern::OpConversionPattern; 1032 LogicalResult 1033 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands, 1034 ConversionPatternRewriter &rewriter) const override { 1035 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front()); 1036 return success(); 1037 } 1038 }; 1039 1040 // Dummy pattern to trigger the appropriate type conversion / materialization. 1041 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1042 public: 1043 using OpConversionPattern::OpConversionPattern; 1044 LogicalResult 1045 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 1046 ConversionPatternRewriter &rewriter) const override { 1047 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands); 1048 return success(); 1049 } 1050 }; 1051 } // namespace 1052 1053 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1054 return std::make_unique<ConvertAsyncToLLVMPass>(); 1055 } 1056 1057 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1058 TypeConverter &typeConverter, RewritePatternSet &patterns, 1059 ConversionTarget &target) { 1060 typeConverter.addConversion([&](TokenType type) { return type; }); 1061 typeConverter.addConversion([&](ValueType type) { 1062 Type converted = typeConverter.convertType(type.getValueType()); 1063 return converted ? ValueType::get(converted) : converted; 1064 }); 1065 1066 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1067 typeConverter, patterns.getContext()); 1068 1069 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1070 [&](Operation *op) { return typeConverter.isLegal(op); }); 1071 } 1072