1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 #define DEBUG_TYPE "convert-async-to-llvm" 24 25 using namespace mlir; 26 using namespace mlir::async; 27 28 //===----------------------------------------------------------------------===// 29 // Async Runtime C API declaration. 30 //===----------------------------------------------------------------------===// 31 32 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 33 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 34 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 35 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 36 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 37 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 38 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 39 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 40 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 41 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 42 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 43 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 44 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 45 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 46 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 47 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 48 static constexpr const char *kGetValueStorage = 49 "mlirAsyncRuntimeGetValueStorage"; 50 static constexpr const char *kAddTokenToGroup = 51 "mlirAsyncRuntimeAddTokenToGroup"; 52 static constexpr const char *kAwaitTokenAndExecute = 53 "mlirAsyncRuntimeAwaitTokenAndExecute"; 54 static constexpr const char *kAwaitValueAndExecute = 55 "mlirAsyncRuntimeAwaitValueAndExecute"; 56 static constexpr const char *kAwaitAllAndExecute = 57 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 58 59 namespace { 60 /// Async Runtime API function types. 61 /// 62 /// Because we can't create API function signature for type parametrized 63 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 64 /// lowering all async data types become opaque pointers at runtime. 65 struct AsyncAPI { 66 // All async types are lowered to opaque i8* LLVM pointers at runtime. 67 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 68 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 69 } 70 71 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 72 return LLVM::LLVMTokenType::get(ctx); 73 } 74 75 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 76 auto ref = opaquePointerType(ctx); 77 auto count = IntegerType::get(ctx, 32); 78 return FunctionType::get(ctx, {ref, count}, {}); 79 } 80 81 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 82 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 83 } 84 85 static FunctionType createValueFunctionType(MLIRContext *ctx) { 86 auto i32 = IntegerType::get(ctx, 32); 87 auto value = opaquePointerType(ctx); 88 return FunctionType::get(ctx, {i32}, {value}); 89 } 90 91 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 92 auto i64 = IntegerType::get(ctx, 64); 93 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 94 } 95 96 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 97 auto value = opaquePointerType(ctx); 98 auto storage = opaquePointerType(ctx); 99 return FunctionType::get(ctx, {value}, {storage}); 100 } 101 102 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 103 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 104 } 105 106 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 107 auto value = opaquePointerType(ctx); 108 return FunctionType::get(ctx, {value}, {}); 109 } 110 111 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 112 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 113 } 114 115 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 116 auto value = opaquePointerType(ctx); 117 return FunctionType::get(ctx, {value}, {}); 118 } 119 120 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 121 auto i1 = IntegerType::get(ctx, 1); 122 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 123 } 124 125 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 126 auto value = opaquePointerType(ctx); 127 auto i1 = IntegerType::get(ctx, 1); 128 return FunctionType::get(ctx, {value}, {i1}); 129 } 130 131 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 132 auto i1 = IntegerType::get(ctx, 1); 133 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 134 } 135 136 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 137 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 138 } 139 140 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 141 auto value = opaquePointerType(ctx); 142 return FunctionType::get(ctx, {value}, {}); 143 } 144 145 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 146 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 147 } 148 149 static FunctionType executeFunctionType(MLIRContext *ctx) { 150 auto hdl = opaquePointerType(ctx); 151 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 152 return FunctionType::get(ctx, {hdl, resume}, {}); 153 } 154 155 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 156 auto i64 = IntegerType::get(ctx, 64); 157 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 158 {i64}); 159 } 160 161 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 162 auto hdl = opaquePointerType(ctx); 163 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 164 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 165 } 166 167 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 168 auto value = opaquePointerType(ctx); 169 auto hdl = opaquePointerType(ctx); 170 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 171 return FunctionType::get(ctx, {value, hdl, resume}, {}); 172 } 173 174 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 175 auto hdl = opaquePointerType(ctx); 176 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 177 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 178 } 179 180 // Auxiliary coroutine resume intrinsic wrapper. 181 static Type resumeFunctionType(MLIRContext *ctx) { 182 auto voidTy = LLVM::LLVMVoidType::get(ctx); 183 auto i8Ptr = opaquePointerType(ctx); 184 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 185 } 186 }; 187 } // namespace 188 189 /// Adds Async Runtime C API declarations to the module. 190 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 191 auto builder = 192 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 193 194 auto addFuncDecl = [&](StringRef name, FunctionType type) { 195 if (module.lookupSymbol(name)) 196 return; 197 builder.create<FuncOp>(name, type).setPrivate(); 198 }; 199 200 MLIRContext *ctx = module.getContext(); 201 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 202 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 203 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 204 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 205 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 206 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 207 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 208 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 209 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 210 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 211 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 212 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 213 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 214 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 215 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 216 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 217 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 218 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 219 addFuncDecl(kAwaitTokenAndExecute, 220 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 221 addFuncDecl(kAwaitValueAndExecute, 222 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 223 addFuncDecl(kAwaitAllAndExecute, 224 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // Add malloc/free declarations to the module. 229 //===----------------------------------------------------------------------===// 230 231 static constexpr const char *kMalloc = "malloc"; 232 static constexpr const char *kFree = "free"; 233 234 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 235 StringRef name, Type ret, ArrayRef<Type> params) { 236 if (module.lookupSymbol(name)) 237 return; 238 Type type = LLVM::LLVMFunctionType::get(ret, params); 239 builder.create<LLVM::LLVMFuncOp>(name, type); 240 } 241 242 /// Adds malloc/free declarations to the module. 243 static void addCRuntimeDeclarations(ModuleOp module) { 244 using namespace mlir::LLVM; 245 246 MLIRContext *ctx = module.getContext(); 247 auto builder = 248 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 249 250 auto voidTy = LLVMVoidType::get(ctx); 251 auto i64 = IntegerType::get(ctx, 64); 252 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 253 254 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 255 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // Coroutine resume function wrapper. 260 //===----------------------------------------------------------------------===// 261 262 static constexpr const char *kResume = "__resume"; 263 264 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 265 /// intrinsics. We need this function to be able to pass it to the async 266 /// runtime execute API. 267 static void addResumeFunction(ModuleOp module) { 268 if (module.lookupSymbol(kResume)) 269 return; 270 271 MLIRContext *ctx = module.getContext(); 272 auto loc = module.getLoc(); 273 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 274 275 auto voidTy = LLVM::LLVMVoidType::get(ctx); 276 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 277 278 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 279 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 280 resumeOp.setPrivate(); 281 282 auto *block = resumeOp.addEntryBlock(); 283 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 284 285 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 286 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // Convert Async dialect types to LLVM types. 291 //===----------------------------------------------------------------------===// 292 293 namespace { 294 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 295 /// their runtime type (opaque pointers) and does not convert any other types. 296 class AsyncRuntimeTypeConverter : public TypeConverter { 297 public: 298 AsyncRuntimeTypeConverter() { 299 addConversion([](Type type) { return type; }); 300 addConversion(convertAsyncTypes); 301 } 302 303 static Optional<Type> convertAsyncTypes(Type type) { 304 if (type.isa<TokenType, GroupType, ValueType>()) 305 return AsyncAPI::opaquePointerType(type.getContext()); 306 307 if (type.isa<CoroIdType, CoroStateType>()) 308 return AsyncAPI::tokenType(type.getContext()); 309 if (type.isa<CoroHandleType>()) 310 return AsyncAPI::opaquePointerType(type.getContext()); 311 312 return llvm::None; 313 } 314 }; 315 } // namespace 316 317 //===----------------------------------------------------------------------===// 318 // Convert async.coro.id to @llvm.coro.id intrinsic. 319 //===----------------------------------------------------------------------===// 320 321 namespace { 322 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 323 public: 324 using OpConversionPattern::OpConversionPattern; 325 326 LogicalResult 327 matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands, 328 ConversionPatternRewriter &rewriter) const override { 329 auto token = AsyncAPI::tokenType(op->getContext()); 330 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 331 auto loc = op->getLoc(); 332 333 // Constants for initializing coroutine frame. 334 auto constZero = rewriter.create<LLVM::ConstantOp>( 335 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 336 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 337 338 // Get coroutine id: @llvm.coro.id. 339 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 340 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 341 342 return success(); 343 } 344 }; 345 } // namespace 346 347 //===----------------------------------------------------------------------===// 348 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 349 //===----------------------------------------------------------------------===// 350 351 namespace { 352 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 353 public: 354 using OpConversionPattern::OpConversionPattern; 355 356 LogicalResult 357 matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands, 358 ConversionPatternRewriter &rewriter) const override { 359 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 360 auto loc = op->getLoc(); 361 362 // Get coroutine frame size: @llvm.coro.size.i64. 363 auto coroSize = 364 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 365 366 // Allocate memory for the coroutine frame. 367 auto coroAlloc = rewriter.create<LLVM::CallOp>( 368 loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc), 369 ValueRange(coroSize.getResult())); 370 371 // Begin a coroutine: @llvm.coro.begin. 372 auto coroId = CoroBeginOpAdaptor(operands).id(); 373 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 374 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 375 376 return success(); 377 } 378 }; 379 } // namespace 380 381 //===----------------------------------------------------------------------===// 382 // Convert async.coro.free to @llvm.coro.free intrinsic. 383 //===----------------------------------------------------------------------===// 384 385 namespace { 386 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 387 public: 388 using OpConversionPattern::OpConversionPattern; 389 390 LogicalResult 391 matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands, 392 ConversionPatternRewriter &rewriter) const override { 393 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 394 auto loc = op->getLoc(); 395 396 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 397 auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands); 398 399 // Free the memory. 400 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 401 rewriter.getSymbolRefAttr(kFree), 402 ValueRange(coroMem.getResult())); 403 404 return success(); 405 } 406 }; 407 } // namespace 408 409 //===----------------------------------------------------------------------===// 410 // Convert async.coro.end to @llvm.coro.end intrinsic. 411 //===----------------------------------------------------------------------===// 412 413 namespace { 414 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 415 public: 416 using OpConversionPattern::OpConversionPattern; 417 418 LogicalResult 419 matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands, 420 ConversionPatternRewriter &rewriter) const override { 421 // We are not in the block that is part of the unwind sequence. 422 auto constFalse = rewriter.create<LLVM::ConstantOp>( 423 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 424 425 // Mark the end of a coroutine: @llvm.coro.end. 426 auto coroHdl = CoroEndOpAdaptor(operands).handle(); 427 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 428 ValueRange({coroHdl, constFalse})); 429 rewriter.eraseOp(op); 430 431 return success(); 432 } 433 }; 434 } // namespace 435 436 //===----------------------------------------------------------------------===// 437 // Convert async.coro.save to @llvm.coro.save intrinsic. 438 //===----------------------------------------------------------------------===// 439 440 namespace { 441 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 442 public: 443 using OpConversionPattern::OpConversionPattern; 444 445 LogicalResult 446 matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands, 447 ConversionPatternRewriter &rewriter) const override { 448 // Save the coroutine state: @llvm.coro.save 449 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 450 op, AsyncAPI::tokenType(op->getContext()), operands); 451 452 return success(); 453 } 454 }; 455 } // namespace 456 457 //===----------------------------------------------------------------------===// 458 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 459 //===----------------------------------------------------------------------===// 460 461 namespace { 462 463 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 464 /// branch to the appropriate block based on the return code. 465 /// 466 /// Before: 467 /// 468 /// ^suspended: 469 /// "opBefore"(...) 470 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 471 /// ^resume: 472 /// "op"(...) 473 /// ^cleanup: ... 474 /// ^suspend: ... 475 /// 476 /// After: 477 /// 478 /// ^suspended: 479 /// "opBefore"(...) 480 /// %suspend = llmv.intr.coro.suspend ... 481 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 482 /// ^resume: 483 /// "op"(...) 484 /// ^cleanup: ... 485 /// ^suspend: ... 486 /// 487 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 488 public: 489 using OpConversionPattern::OpConversionPattern; 490 491 LogicalResult 492 matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands, 493 ConversionPatternRewriter &rewriter) const override { 494 auto i8 = rewriter.getIntegerType(8); 495 auto i32 = rewriter.getI32Type(); 496 auto loc = op->getLoc(); 497 498 // This is not a final suspension point. 499 auto constFalse = rewriter.create<LLVM::ConstantOp>( 500 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 501 502 // Suspend a coroutine: @llvm.coro.suspend 503 auto coroState = CoroSuspendOpAdaptor(operands).state(); 504 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 505 loc, i8, ValueRange({coroState, constFalse})); 506 507 // Cast return code to i32. 508 509 // After a suspension point decide if we should branch into resume, cleanup 510 // or suspend block of the coroutine (see @llvm.coro.suspend return code 511 // documentation). 512 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 513 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 514 op.cleanupDest()}; 515 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 516 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 517 /*defaultDestination=*/op.suspendDest(), 518 /*defaultOperands=*/ValueRange(), 519 /*caseValues=*/caseValues, 520 /*caseDestinations=*/caseDest, 521 /*caseOperands=*/ArrayRef<ValueRange>(), 522 /*branchWeights=*/ArrayRef<int32_t>()); 523 524 return success(); 525 } 526 }; 527 } // namespace 528 529 //===----------------------------------------------------------------------===// 530 // Convert async.runtime.create to the corresponding runtime API call. 531 // 532 // To allocate storage for the async values we use getelementptr trick: 533 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 534 //===----------------------------------------------------------------------===// 535 536 namespace { 537 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 538 public: 539 using OpConversionPattern::OpConversionPattern; 540 541 LogicalResult 542 matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands, 543 ConversionPatternRewriter &rewriter) const override { 544 TypeConverter *converter = getTypeConverter(); 545 Type resultType = op->getResultTypes()[0]; 546 547 // Tokens creation maps to a simple function call. 548 if (resultType.isa<TokenType>()) { 549 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken, 550 converter->convertType(resultType)); 551 return success(); 552 } 553 554 // To create a value we need to compute the storage requirement. 555 if (auto value = resultType.dyn_cast<ValueType>()) { 556 // Returns the size requirements for the async value storage. 557 auto sizeOf = [&](ValueType valueType) -> Value { 558 auto loc = op->getLoc(); 559 auto i32 = rewriter.getI32Type(); 560 561 auto storedType = converter->convertType(valueType.getValueType()); 562 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 563 564 // %Size = getelementptr %T* null, int 1 565 // %SizeI = ptrtoint %T* %Size to i32 566 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 567 auto one = rewriter.create<LLVM::ConstantOp>( 568 loc, i32, rewriter.getI32IntegerAttr(1)); 569 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 570 one.getResult()); 571 return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep); 572 }; 573 574 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 575 sizeOf(value)); 576 577 return success(); 578 } 579 580 return rewriter.notifyMatchFailure(op, "unsupported async type"); 581 } 582 }; 583 } // namespace 584 585 //===----------------------------------------------------------------------===// 586 // Convert async.runtime.create_group to the corresponding runtime API call. 587 //===----------------------------------------------------------------------===// 588 589 namespace { 590 class RuntimeCreateGroupOpLowering 591 : public OpConversionPattern<RuntimeCreateGroupOp> { 592 public: 593 using OpConversionPattern::OpConversionPattern; 594 595 LogicalResult 596 matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands, 597 ConversionPatternRewriter &rewriter) const override { 598 TypeConverter *converter = getTypeConverter(); 599 Type resultType = op.getResult().getType(); 600 601 rewriter.replaceOpWithNewOp<CallOp>( 602 op, kCreateGroup, converter->convertType(resultType), operands); 603 return success(); 604 } 605 }; 606 } // namespace 607 608 //===----------------------------------------------------------------------===// 609 // Convert async.runtime.set_available to the corresponding runtime API call. 610 //===----------------------------------------------------------------------===// 611 612 namespace { 613 class RuntimeSetAvailableOpLowering 614 : public OpConversionPattern<RuntimeSetAvailableOp> { 615 public: 616 using OpConversionPattern::OpConversionPattern; 617 618 LogicalResult 619 matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands, 620 ConversionPatternRewriter &rewriter) const override { 621 StringRef apiFuncName = 622 TypeSwitch<Type, StringRef>(op.operand().getType()) 623 .Case<TokenType>([](Type) { return kEmplaceToken; }) 624 .Case<ValueType>([](Type) { return kEmplaceValue; }); 625 626 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands); 627 628 return success(); 629 } 630 }; 631 } // namespace 632 633 //===----------------------------------------------------------------------===// 634 // Convert async.runtime.set_error to the corresponding runtime API call. 635 //===----------------------------------------------------------------------===// 636 637 namespace { 638 class RuntimeSetErrorOpLowering 639 : public OpConversionPattern<RuntimeSetErrorOp> { 640 public: 641 using OpConversionPattern::OpConversionPattern; 642 643 LogicalResult 644 matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands, 645 ConversionPatternRewriter &rewriter) const override { 646 StringRef apiFuncName = 647 TypeSwitch<Type, StringRef>(op.operand().getType()) 648 .Case<TokenType>([](Type) { return kSetTokenError; }) 649 .Case<ValueType>([](Type) { return kSetValueError; }); 650 651 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands); 652 653 return success(); 654 } 655 }; 656 } // namespace 657 658 //===----------------------------------------------------------------------===// 659 // Convert async.runtime.is_error to the corresponding runtime API call. 660 //===----------------------------------------------------------------------===// 661 662 namespace { 663 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 664 public: 665 using OpConversionPattern::OpConversionPattern; 666 667 LogicalResult 668 matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands, 669 ConversionPatternRewriter &rewriter) const override { 670 StringRef apiFuncName = 671 TypeSwitch<Type, StringRef>(op.operand().getType()) 672 .Case<TokenType>([](Type) { return kIsTokenError; }) 673 .Case<GroupType>([](Type) { return kIsGroupError; }) 674 .Case<ValueType>([](Type) { return kIsValueError; }); 675 676 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 677 operands); 678 return success(); 679 } 680 }; 681 } // namespace 682 683 //===----------------------------------------------------------------------===// 684 // Convert async.runtime.await to the corresponding runtime API call. 685 //===----------------------------------------------------------------------===// 686 687 namespace { 688 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 689 public: 690 using OpConversionPattern::OpConversionPattern; 691 692 LogicalResult 693 matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands, 694 ConversionPatternRewriter &rewriter) const override { 695 StringRef apiFuncName = 696 TypeSwitch<Type, StringRef>(op.operand().getType()) 697 .Case<TokenType>([](Type) { return kAwaitToken; }) 698 .Case<ValueType>([](Type) { return kAwaitValue; }) 699 .Case<GroupType>([](Type) { return kAwaitGroup; }); 700 701 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands); 702 rewriter.eraseOp(op); 703 704 return success(); 705 } 706 }; 707 } // namespace 708 709 //===----------------------------------------------------------------------===// 710 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 711 //===----------------------------------------------------------------------===// 712 713 namespace { 714 class RuntimeAwaitAndResumeOpLowering 715 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 716 public: 717 using OpConversionPattern::OpConversionPattern; 718 719 LogicalResult 720 matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands, 721 ConversionPatternRewriter &rewriter) const override { 722 StringRef apiFuncName = 723 TypeSwitch<Type, StringRef>(op.operand().getType()) 724 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 725 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 726 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 727 728 Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); 729 Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); 730 731 // A pointer to coroutine resume intrinsic wrapper. 732 addResumeFunction(op->getParentOfType<ModuleOp>()); 733 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 734 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 735 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 736 737 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 738 ValueRange({operand, handle, resumePtr.res()})); 739 rewriter.eraseOp(op); 740 741 return success(); 742 } 743 }; 744 } // namespace 745 746 //===----------------------------------------------------------------------===// 747 // Convert async.runtime.resume to the corresponding runtime API call. 748 //===----------------------------------------------------------------------===// 749 750 namespace { 751 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 752 public: 753 using OpConversionPattern::OpConversionPattern; 754 755 LogicalResult 756 matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands, 757 ConversionPatternRewriter &rewriter) const override { 758 // A pointer to coroutine resume intrinsic wrapper. 759 addResumeFunction(op->getParentOfType<ModuleOp>()); 760 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 761 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 762 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 763 764 // Call async runtime API to execute a coroutine in the managed thread. 765 auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); 766 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute, 767 ValueRange({coroHdl, resumePtr.res()})); 768 769 return success(); 770 } 771 }; 772 } // namespace 773 774 //===----------------------------------------------------------------------===// 775 // Convert async.runtime.store to the corresponding runtime API call. 776 //===----------------------------------------------------------------------===// 777 778 namespace { 779 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 780 public: 781 using OpConversionPattern::OpConversionPattern; 782 783 LogicalResult 784 matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands, 785 ConversionPatternRewriter &rewriter) const override { 786 Location loc = op->getLoc(); 787 788 // Get a pointer to the async value storage from the runtime. 789 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 790 auto storage = RuntimeStoreOpAdaptor(operands).storage(); 791 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 792 TypeRange(i8Ptr), storage); 793 794 // Cast from i8* to the LLVM pointer type. 795 auto valueType = op.value().getType(); 796 auto llvmValueType = getTypeConverter()->convertType(valueType); 797 if (!llvmValueType) 798 return rewriter.notifyMatchFailure( 799 op, "failed to convert stored value type to LLVM type"); 800 801 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 802 loc, LLVM::LLVMPointerType::get(llvmValueType), 803 storagePtr.getResult(0)); 804 805 // Store the yielded value into the async value storage. 806 auto value = RuntimeStoreOpAdaptor(operands).value(); 807 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 808 809 // Erase the original runtime store operation. 810 rewriter.eraseOp(op); 811 812 return success(); 813 } 814 }; 815 } // namespace 816 817 //===----------------------------------------------------------------------===// 818 // Convert async.runtime.load to the corresponding runtime API call. 819 //===----------------------------------------------------------------------===// 820 821 namespace { 822 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 823 public: 824 using OpConversionPattern::OpConversionPattern; 825 826 LogicalResult 827 matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands, 828 ConversionPatternRewriter &rewriter) const override { 829 Location loc = op->getLoc(); 830 831 // Get a pointer to the async value storage from the runtime. 832 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 833 auto storage = RuntimeLoadOpAdaptor(operands).storage(); 834 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 835 TypeRange(i8Ptr), storage); 836 837 // Cast from i8* to the LLVM pointer type. 838 auto valueType = op.result().getType(); 839 auto llvmValueType = getTypeConverter()->convertType(valueType); 840 if (!llvmValueType) 841 return rewriter.notifyMatchFailure( 842 op, "failed to convert loaded value type to LLVM type"); 843 844 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 845 loc, LLVM::LLVMPointerType::get(llvmValueType), 846 storagePtr.getResult(0)); 847 848 // Load from the casted pointer. 849 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 850 851 return success(); 852 } 853 }; 854 } // namespace 855 856 //===----------------------------------------------------------------------===// 857 // Convert async.runtime.add_to_group to the corresponding runtime API call. 858 //===----------------------------------------------------------------------===// 859 860 namespace { 861 class RuntimeAddToGroupOpLowering 862 : public OpConversionPattern<RuntimeAddToGroupOp> { 863 public: 864 using OpConversionPattern::OpConversionPattern; 865 866 LogicalResult 867 matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands, 868 ConversionPatternRewriter &rewriter) const override { 869 // Currently we can only add tokens to the group. 870 if (!op.operand().getType().isa<TokenType>()) 871 return rewriter.notifyMatchFailure(op, "only token type is supported"); 872 873 // Replace with a runtime API function call. 874 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, 875 rewriter.getI64Type(), operands); 876 877 return success(); 878 } 879 }; 880 } // namespace 881 882 //===----------------------------------------------------------------------===// 883 // Async reference counting ops lowering (`async.runtime.add_ref` and 884 // `async.runtime.drop_ref` to the corresponding API calls). 885 //===----------------------------------------------------------------------===// 886 887 namespace { 888 template <typename RefCountingOp> 889 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 890 public: 891 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 892 StringRef apiFunctionName) 893 : OpConversionPattern<RefCountingOp>(converter, ctx), 894 apiFunctionName(apiFunctionName) {} 895 896 LogicalResult 897 matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands, 898 ConversionPatternRewriter &rewriter) const override { 899 auto count = 900 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(), 901 rewriter.getI32IntegerAttr(op.count())); 902 903 auto operand = typename RefCountingOp::Adaptor(operands).operand(); 904 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 905 ValueRange({operand, count})); 906 907 return success(); 908 } 909 910 private: 911 StringRef apiFunctionName; 912 }; 913 914 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 915 public: 916 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 917 : RefCountingOpLowering(converter, ctx, kAddRef) {} 918 }; 919 920 class RuntimeDropRefOpLowering 921 : public RefCountingOpLowering<RuntimeDropRefOp> { 922 public: 923 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 924 : RefCountingOpLowering(converter, ctx, kDropRef) {} 925 }; 926 } // namespace 927 928 //===----------------------------------------------------------------------===// 929 // Convert return operations that return async values from async regions. 930 //===----------------------------------------------------------------------===// 931 932 namespace { 933 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 934 public: 935 using OpConversionPattern::OpConversionPattern; 936 937 LogicalResult 938 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 939 ConversionPatternRewriter &rewriter) const override { 940 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 941 return success(); 942 } 943 }; 944 } // namespace 945 946 //===----------------------------------------------------------------------===// 947 948 namespace { 949 struct ConvertAsyncToLLVMPass 950 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 951 void runOnOperation() override; 952 }; 953 } // namespace 954 955 void ConvertAsyncToLLVMPass::runOnOperation() { 956 ModuleOp module = getOperation(); 957 MLIRContext *ctx = module->getContext(); 958 959 // Add declarations for most functions required by the coroutines lowering. 960 // We delay adding the resume function until it's needed because it currently 961 // fails to compile unless '-O0' is specified. 962 addAsyncRuntimeApiDeclarations(module); 963 addCRuntimeDeclarations(module); 964 965 // Lower async.runtime and async.coro operations to Async Runtime API and 966 // LLVM coroutine intrinsics. 967 968 // Convert async dialect types and operations to LLVM dialect. 969 AsyncRuntimeTypeConverter converter; 970 RewritePatternSet patterns(ctx); 971 972 // We use conversion to LLVM type to lower async.runtime load and store 973 // operations. 974 LLVMTypeConverter llvmConverter(ctx); 975 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 976 977 // Convert async types in function signatures and function calls. 978 populateFuncOpTypeConversionPattern(patterns, converter); 979 populateCallOpTypeConversionPattern(patterns, converter); 980 981 // Convert return operations inside async.execute regions. 982 patterns.add<ReturnOpOpConversion>(converter, ctx); 983 984 // Lower async.runtime operations to the async runtime API calls. 985 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 986 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 987 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 988 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 989 RuntimeDropRefOpLowering>(converter, ctx); 990 991 // Lower async.runtime operations that rely on LLVM type converter to convert 992 // from async value payload type to the LLVM type. 993 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 994 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 995 ctx); 996 997 // Lower async coroutine operations to LLVM coroutine intrinsics. 998 patterns 999 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 1000 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1001 converter, ctx); 1002 1003 ConversionTarget target(*ctx); 1004 target.addLegalOp<ConstantOp>(); 1005 target.addLegalDialect<LLVM::LLVMDialect>(); 1006 1007 // All operations from Async dialect must be lowered to the runtime API and 1008 // LLVM intrinsics calls. 1009 target.addIllegalDialect<AsyncDialect>(); 1010 1011 // Add dynamic legality constraints to apply conversions defined above. 1012 target.addDynamicallyLegalOp<FuncOp>( 1013 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1014 target.addDynamicallyLegalOp<ReturnOp>( 1015 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1016 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1017 return converter.isSignatureLegal(op.getCalleeType()); 1018 }); 1019 1020 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1021 signalPassFailure(); 1022 } 1023 1024 //===----------------------------------------------------------------------===// 1025 // Patterns for structural type conversions for the Async dialect operations. 1026 //===----------------------------------------------------------------------===// 1027 1028 namespace { 1029 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1030 public: 1031 using OpConversionPattern::OpConversionPattern; 1032 LogicalResult 1033 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands, 1034 ConversionPatternRewriter &rewriter) const override { 1035 ExecuteOp newOp = 1036 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1037 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1038 newOp.getRegion().end()); 1039 1040 // Set operands and update block argument and result types. 1041 newOp->setOperands(operands); 1042 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1043 return failure(); 1044 for (auto result : newOp.getResults()) 1045 result.setType(typeConverter->convertType(result.getType())); 1046 1047 rewriter.replaceOp(op, newOp.getResults()); 1048 return success(); 1049 } 1050 }; 1051 1052 // Dummy pattern to trigger the appropriate type conversion / materialization. 1053 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1054 public: 1055 using OpConversionPattern::OpConversionPattern; 1056 LogicalResult 1057 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands, 1058 ConversionPatternRewriter &rewriter) const override { 1059 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front()); 1060 return success(); 1061 } 1062 }; 1063 1064 // Dummy pattern to trigger the appropriate type conversion / materialization. 1065 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1066 public: 1067 using OpConversionPattern::OpConversionPattern; 1068 LogicalResult 1069 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 1070 ConversionPatternRewriter &rewriter) const override { 1071 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands); 1072 return success(); 1073 } 1074 }; 1075 } // namespace 1076 1077 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1078 return std::make_unique<ConvertAsyncToLLVMPass>(); 1079 } 1080 1081 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1082 TypeConverter &typeConverter, RewritePatternSet &patterns, 1083 ConversionTarget &target) { 1084 typeConverter.addConversion([&](TokenType type) { return type; }); 1085 typeConverter.addConversion([&](ValueType type) { 1086 Type converted = typeConverter.convertType(type.getValueType()); 1087 return converted ? ValueType::get(converted) : converted; 1088 }); 1089 1090 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1091 typeConverter, patterns.getContext()); 1092 1093 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1094 [&](Operation *op) { return typeConverter.isLegal(op); }); 1095 } 1096