1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Dialect/Async/IR/Async.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Pass/Pass.h" 22 #include "mlir/Transforms/DialectConversion.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 25 #define DEBUG_TYPE "convert-async-to-llvm" 26 27 using namespace mlir; 28 using namespace mlir::async; 29 30 //===----------------------------------------------------------------------===// 31 // Async Runtime C API declaration. 32 //===----------------------------------------------------------------------===// 33 34 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 35 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 36 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 37 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 38 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 39 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 40 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 41 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 42 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 43 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 44 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 45 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 46 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 47 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 48 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 49 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 50 static constexpr const char *kGetValueStorage = 51 "mlirAsyncRuntimeGetValueStorage"; 52 static constexpr const char *kAddTokenToGroup = 53 "mlirAsyncRuntimeAddTokenToGroup"; 54 static constexpr const char *kAwaitTokenAndExecute = 55 "mlirAsyncRuntimeAwaitTokenAndExecute"; 56 static constexpr const char *kAwaitValueAndExecute = 57 "mlirAsyncRuntimeAwaitValueAndExecute"; 58 static constexpr const char *kAwaitAllAndExecute = 59 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 60 61 namespace { 62 /// Async Runtime API function types. 63 /// 64 /// Because we can't create API function signature for type parametrized 65 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 66 /// lowering all async data types become opaque pointers at runtime. 67 struct AsyncAPI { 68 // All async types are lowered to opaque i8* LLVM pointers at runtime. 69 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 70 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 71 } 72 73 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 74 return LLVM::LLVMTokenType::get(ctx); 75 } 76 77 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 78 auto ref = opaquePointerType(ctx); 79 auto count = IntegerType::get(ctx, 32); 80 return FunctionType::get(ctx, {ref, count}, {}); 81 } 82 83 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 84 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 85 } 86 87 static FunctionType createValueFunctionType(MLIRContext *ctx) { 88 auto i32 = IntegerType::get(ctx, 32); 89 auto value = opaquePointerType(ctx); 90 return FunctionType::get(ctx, {i32}, {value}); 91 } 92 93 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 94 auto i64 = IntegerType::get(ctx, 64); 95 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 96 } 97 98 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 99 auto value = opaquePointerType(ctx); 100 auto storage = opaquePointerType(ctx); 101 return FunctionType::get(ctx, {value}, {storage}); 102 } 103 104 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 105 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 106 } 107 108 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 109 auto value = opaquePointerType(ctx); 110 return FunctionType::get(ctx, {value}, {}); 111 } 112 113 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 114 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 115 } 116 117 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 118 auto value = opaquePointerType(ctx); 119 return FunctionType::get(ctx, {value}, {}); 120 } 121 122 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 123 auto i1 = IntegerType::get(ctx, 1); 124 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 125 } 126 127 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 128 auto value = opaquePointerType(ctx); 129 auto i1 = IntegerType::get(ctx, 1); 130 return FunctionType::get(ctx, {value}, {i1}); 131 } 132 133 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 134 auto i1 = IntegerType::get(ctx, 1); 135 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 136 } 137 138 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 139 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 140 } 141 142 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 143 auto value = opaquePointerType(ctx); 144 return FunctionType::get(ctx, {value}, {}); 145 } 146 147 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 148 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 149 } 150 151 static FunctionType executeFunctionType(MLIRContext *ctx) { 152 auto hdl = opaquePointerType(ctx); 153 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 154 return FunctionType::get(ctx, {hdl, resume}, {}); 155 } 156 157 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 158 auto i64 = IntegerType::get(ctx, 64); 159 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 160 {i64}); 161 } 162 163 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 164 auto hdl = opaquePointerType(ctx); 165 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 166 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 167 } 168 169 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 170 auto value = opaquePointerType(ctx); 171 auto hdl = opaquePointerType(ctx); 172 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 173 return FunctionType::get(ctx, {value, hdl, resume}, {}); 174 } 175 176 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 177 auto hdl = opaquePointerType(ctx); 178 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 179 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 180 } 181 182 // Auxiliary coroutine resume intrinsic wrapper. 183 static Type resumeFunctionType(MLIRContext *ctx) { 184 auto voidTy = LLVM::LLVMVoidType::get(ctx); 185 auto i8Ptr = opaquePointerType(ctx); 186 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 187 } 188 }; 189 } // namespace 190 191 /// Adds Async Runtime C API declarations to the module. 192 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 193 auto builder = 194 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 195 196 auto addFuncDecl = [&](StringRef name, FunctionType type) { 197 if (module.lookupSymbol(name)) 198 return; 199 builder.create<FuncOp>(name, type).setPrivate(); 200 }; 201 202 MLIRContext *ctx = module.getContext(); 203 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 204 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 205 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 206 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 207 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 208 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 209 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 210 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 211 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 212 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 213 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 214 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 215 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 216 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 217 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 218 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 219 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 220 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 221 addFuncDecl(kAwaitTokenAndExecute, 222 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 223 addFuncDecl(kAwaitValueAndExecute, 224 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 225 addFuncDecl(kAwaitAllAndExecute, 226 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 227 } 228 229 //===----------------------------------------------------------------------===// 230 // Add malloc/free declarations to the module. 231 //===----------------------------------------------------------------------===// 232 233 static constexpr const char *kMalloc = "malloc"; 234 static constexpr const char *kFree = "free"; 235 236 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 237 StringRef name, Type ret, ArrayRef<Type> params) { 238 if (module.lookupSymbol(name)) 239 return; 240 Type type = LLVM::LLVMFunctionType::get(ret, params); 241 builder.create<LLVM::LLVMFuncOp>(name, type); 242 } 243 244 /// Adds malloc/free declarations to the module. 245 static void addCRuntimeDeclarations(ModuleOp module) { 246 using namespace mlir::LLVM; 247 248 MLIRContext *ctx = module.getContext(); 249 auto builder = 250 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 251 252 auto voidTy = LLVMVoidType::get(ctx); 253 auto i64 = IntegerType::get(ctx, 64); 254 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 255 256 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 257 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // Coroutine resume function wrapper. 262 //===----------------------------------------------------------------------===// 263 264 static constexpr const char *kResume = "__resume"; 265 266 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 267 /// intrinsics. We need this function to be able to pass it to the async 268 /// runtime execute API. 269 static void addResumeFunction(ModuleOp module) { 270 if (module.lookupSymbol(kResume)) 271 return; 272 273 MLIRContext *ctx = module.getContext(); 274 auto loc = module.getLoc(); 275 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 276 277 auto voidTy = LLVM::LLVMVoidType::get(ctx); 278 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 279 280 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 281 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 282 resumeOp.setPrivate(); 283 284 auto *block = resumeOp.addEntryBlock(); 285 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 286 287 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 288 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 289 } 290 291 //===----------------------------------------------------------------------===// 292 // Convert Async dialect types to LLVM types. 293 //===----------------------------------------------------------------------===// 294 295 namespace { 296 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 297 /// their runtime type (opaque pointers) and does not convert any other types. 298 class AsyncRuntimeTypeConverter : public TypeConverter { 299 public: 300 AsyncRuntimeTypeConverter() { 301 addConversion([](Type type) { return type; }); 302 addConversion(convertAsyncTypes); 303 } 304 305 static Optional<Type> convertAsyncTypes(Type type) { 306 if (type.isa<TokenType, GroupType, ValueType>()) 307 return AsyncAPI::opaquePointerType(type.getContext()); 308 309 if (type.isa<CoroIdType, CoroStateType>()) 310 return AsyncAPI::tokenType(type.getContext()); 311 if (type.isa<CoroHandleType>()) 312 return AsyncAPI::opaquePointerType(type.getContext()); 313 314 return llvm::None; 315 } 316 }; 317 } // namespace 318 319 //===----------------------------------------------------------------------===// 320 // Convert async.coro.id to @llvm.coro.id intrinsic. 321 //===----------------------------------------------------------------------===// 322 323 namespace { 324 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 325 public: 326 using OpConversionPattern::OpConversionPattern; 327 328 LogicalResult 329 matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands, 330 ConversionPatternRewriter &rewriter) const override { 331 auto token = AsyncAPI::tokenType(op->getContext()); 332 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 333 auto loc = op->getLoc(); 334 335 // Constants for initializing coroutine frame. 336 auto constZero = rewriter.create<LLVM::ConstantOp>( 337 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 338 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 339 340 // Get coroutine id: @llvm.coro.id. 341 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 342 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 343 344 return success(); 345 } 346 }; 347 } // namespace 348 349 //===----------------------------------------------------------------------===// 350 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 351 //===----------------------------------------------------------------------===// 352 353 namespace { 354 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 355 public: 356 using OpConversionPattern::OpConversionPattern; 357 358 LogicalResult 359 matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands, 360 ConversionPatternRewriter &rewriter) const override { 361 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 362 auto loc = op->getLoc(); 363 364 // Get coroutine frame size: @llvm.coro.size.i64. 365 auto coroSize = 366 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 367 368 // Allocate memory for the coroutine frame. 369 auto coroAlloc = rewriter.create<LLVM::CallOp>( 370 loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc), 371 ValueRange(coroSize.getResult())); 372 373 // Begin a coroutine: @llvm.coro.begin. 374 auto coroId = CoroBeginOpAdaptor(operands).id(); 375 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 376 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 377 378 return success(); 379 } 380 }; 381 } // namespace 382 383 //===----------------------------------------------------------------------===// 384 // Convert async.coro.free to @llvm.coro.free intrinsic. 385 //===----------------------------------------------------------------------===// 386 387 namespace { 388 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 389 public: 390 using OpConversionPattern::OpConversionPattern; 391 392 LogicalResult 393 matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands, 394 ConversionPatternRewriter &rewriter) const override { 395 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 396 auto loc = op->getLoc(); 397 398 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 399 auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands); 400 401 // Free the memory. 402 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 403 op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree), 404 ValueRange(coroMem.getResult())); 405 406 return success(); 407 } 408 }; 409 } // namespace 410 411 //===----------------------------------------------------------------------===// 412 // Convert async.coro.end to @llvm.coro.end intrinsic. 413 //===----------------------------------------------------------------------===// 414 415 namespace { 416 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 417 public: 418 using OpConversionPattern::OpConversionPattern; 419 420 LogicalResult 421 matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands, 422 ConversionPatternRewriter &rewriter) const override { 423 // We are not in the block that is part of the unwind sequence. 424 auto constFalse = rewriter.create<LLVM::ConstantOp>( 425 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 426 427 // Mark the end of a coroutine: @llvm.coro.end. 428 auto coroHdl = CoroEndOpAdaptor(operands).handle(); 429 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 430 ValueRange({coroHdl, constFalse})); 431 rewriter.eraseOp(op); 432 433 return success(); 434 } 435 }; 436 } // namespace 437 438 //===----------------------------------------------------------------------===// 439 // Convert async.coro.save to @llvm.coro.save intrinsic. 440 //===----------------------------------------------------------------------===// 441 442 namespace { 443 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 444 public: 445 using OpConversionPattern::OpConversionPattern; 446 447 LogicalResult 448 matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands, 449 ConversionPatternRewriter &rewriter) const override { 450 // Save the coroutine state: @llvm.coro.save 451 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 452 op, AsyncAPI::tokenType(op->getContext()), operands); 453 454 return success(); 455 } 456 }; 457 } // namespace 458 459 //===----------------------------------------------------------------------===// 460 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 461 //===----------------------------------------------------------------------===// 462 463 namespace { 464 465 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 466 /// branch to the appropriate block based on the return code. 467 /// 468 /// Before: 469 /// 470 /// ^suspended: 471 /// "opBefore"(...) 472 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 473 /// ^resume: 474 /// "op"(...) 475 /// ^cleanup: ... 476 /// ^suspend: ... 477 /// 478 /// After: 479 /// 480 /// ^suspended: 481 /// "opBefore"(...) 482 /// %suspend = llmv.intr.coro.suspend ... 483 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 484 /// ^resume: 485 /// "op"(...) 486 /// ^cleanup: ... 487 /// ^suspend: ... 488 /// 489 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 490 public: 491 using OpConversionPattern::OpConversionPattern; 492 493 LogicalResult 494 matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands, 495 ConversionPatternRewriter &rewriter) const override { 496 auto i8 = rewriter.getIntegerType(8); 497 auto i32 = rewriter.getI32Type(); 498 auto loc = op->getLoc(); 499 500 // This is not a final suspension point. 501 auto constFalse = rewriter.create<LLVM::ConstantOp>( 502 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 503 504 // Suspend a coroutine: @llvm.coro.suspend 505 auto coroState = CoroSuspendOpAdaptor(operands).state(); 506 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 507 loc, i8, ValueRange({coroState, constFalse})); 508 509 // Cast return code to i32. 510 511 // After a suspension point decide if we should branch into resume, cleanup 512 // or suspend block of the coroutine (see @llvm.coro.suspend return code 513 // documentation). 514 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 515 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 516 op.cleanupDest()}; 517 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 518 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 519 /*defaultDestination=*/op.suspendDest(), 520 /*defaultOperands=*/ValueRange(), 521 /*caseValues=*/caseValues, 522 /*caseDestinations=*/caseDest, 523 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 524 /*branchWeights=*/ArrayRef<int32_t>()); 525 526 return success(); 527 } 528 }; 529 } // namespace 530 531 //===----------------------------------------------------------------------===// 532 // Convert async.runtime.create to the corresponding runtime API call. 533 // 534 // To allocate storage for the async values we use getelementptr trick: 535 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 536 //===----------------------------------------------------------------------===// 537 538 namespace { 539 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 540 public: 541 using OpConversionPattern::OpConversionPattern; 542 543 LogicalResult 544 matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands, 545 ConversionPatternRewriter &rewriter) const override { 546 TypeConverter *converter = getTypeConverter(); 547 Type resultType = op->getResultTypes()[0]; 548 549 // Tokens creation maps to a simple function call. 550 if (resultType.isa<TokenType>()) { 551 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken, 552 converter->convertType(resultType)); 553 return success(); 554 } 555 556 // To create a value we need to compute the storage requirement. 557 if (auto value = resultType.dyn_cast<ValueType>()) { 558 // Returns the size requirements for the async value storage. 559 auto sizeOf = [&](ValueType valueType) -> Value { 560 auto loc = op->getLoc(); 561 auto i32 = rewriter.getI32Type(); 562 563 auto storedType = converter->convertType(valueType.getValueType()); 564 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 565 566 // %Size = getelementptr %T* null, int 1 567 // %SizeI = ptrtoint %T* %Size to i32 568 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 569 auto one = rewriter.create<LLVM::ConstantOp>( 570 loc, i32, rewriter.getI32IntegerAttr(1)); 571 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 572 one.getResult()); 573 return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep); 574 }; 575 576 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 577 sizeOf(value)); 578 579 return success(); 580 } 581 582 return rewriter.notifyMatchFailure(op, "unsupported async type"); 583 } 584 }; 585 } // namespace 586 587 //===----------------------------------------------------------------------===// 588 // Convert async.runtime.create_group to the corresponding runtime API call. 589 //===----------------------------------------------------------------------===// 590 591 namespace { 592 class RuntimeCreateGroupOpLowering 593 : public OpConversionPattern<RuntimeCreateGroupOp> { 594 public: 595 using OpConversionPattern::OpConversionPattern; 596 597 LogicalResult 598 matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands, 599 ConversionPatternRewriter &rewriter) const override { 600 TypeConverter *converter = getTypeConverter(); 601 Type resultType = op.getResult().getType(); 602 603 rewriter.replaceOpWithNewOp<CallOp>( 604 op, kCreateGroup, converter->convertType(resultType), operands); 605 return success(); 606 } 607 }; 608 } // namespace 609 610 //===----------------------------------------------------------------------===// 611 // Convert async.runtime.set_available to the corresponding runtime API call. 612 //===----------------------------------------------------------------------===// 613 614 namespace { 615 class RuntimeSetAvailableOpLowering 616 : public OpConversionPattern<RuntimeSetAvailableOp> { 617 public: 618 using OpConversionPattern::OpConversionPattern; 619 620 LogicalResult 621 matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands, 622 ConversionPatternRewriter &rewriter) const override { 623 StringRef apiFuncName = 624 TypeSwitch<Type, StringRef>(op.operand().getType()) 625 .Case<TokenType>([](Type) { return kEmplaceToken; }) 626 .Case<ValueType>([](Type) { return kEmplaceValue; }); 627 628 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands); 629 630 return success(); 631 } 632 }; 633 } // namespace 634 635 //===----------------------------------------------------------------------===// 636 // Convert async.runtime.set_error to the corresponding runtime API call. 637 //===----------------------------------------------------------------------===// 638 639 namespace { 640 class RuntimeSetErrorOpLowering 641 : public OpConversionPattern<RuntimeSetErrorOp> { 642 public: 643 using OpConversionPattern::OpConversionPattern; 644 645 LogicalResult 646 matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands, 647 ConversionPatternRewriter &rewriter) const override { 648 StringRef apiFuncName = 649 TypeSwitch<Type, StringRef>(op.operand().getType()) 650 .Case<TokenType>([](Type) { return kSetTokenError; }) 651 .Case<ValueType>([](Type) { return kSetValueError; }); 652 653 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands); 654 655 return success(); 656 } 657 }; 658 } // namespace 659 660 //===----------------------------------------------------------------------===// 661 // Convert async.runtime.is_error to the corresponding runtime API call. 662 //===----------------------------------------------------------------------===// 663 664 namespace { 665 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 666 public: 667 using OpConversionPattern::OpConversionPattern; 668 669 LogicalResult 670 matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands, 671 ConversionPatternRewriter &rewriter) const override { 672 StringRef apiFuncName = 673 TypeSwitch<Type, StringRef>(op.operand().getType()) 674 .Case<TokenType>([](Type) { return kIsTokenError; }) 675 .Case<GroupType>([](Type) { return kIsGroupError; }) 676 .Case<ValueType>([](Type) { return kIsValueError; }); 677 678 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 679 operands); 680 return success(); 681 } 682 }; 683 } // namespace 684 685 //===----------------------------------------------------------------------===// 686 // Convert async.runtime.await to the corresponding runtime API call. 687 //===----------------------------------------------------------------------===// 688 689 namespace { 690 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 691 public: 692 using OpConversionPattern::OpConversionPattern; 693 694 LogicalResult 695 matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands, 696 ConversionPatternRewriter &rewriter) const override { 697 StringRef apiFuncName = 698 TypeSwitch<Type, StringRef>(op.operand().getType()) 699 .Case<TokenType>([](Type) { return kAwaitToken; }) 700 .Case<ValueType>([](Type) { return kAwaitValue; }) 701 .Case<GroupType>([](Type) { return kAwaitGroup; }); 702 703 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands); 704 rewriter.eraseOp(op); 705 706 return success(); 707 } 708 }; 709 } // namespace 710 711 //===----------------------------------------------------------------------===// 712 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 713 //===----------------------------------------------------------------------===// 714 715 namespace { 716 class RuntimeAwaitAndResumeOpLowering 717 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 718 public: 719 using OpConversionPattern::OpConversionPattern; 720 721 LogicalResult 722 matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands, 723 ConversionPatternRewriter &rewriter) const override { 724 StringRef apiFuncName = 725 TypeSwitch<Type, StringRef>(op.operand().getType()) 726 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 727 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 728 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 729 730 Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); 731 Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); 732 733 // A pointer to coroutine resume intrinsic wrapper. 734 addResumeFunction(op->getParentOfType<ModuleOp>()); 735 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 736 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 737 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 738 739 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 740 ValueRange({operand, handle, resumePtr.res()})); 741 rewriter.eraseOp(op); 742 743 return success(); 744 } 745 }; 746 } // namespace 747 748 //===----------------------------------------------------------------------===// 749 // Convert async.runtime.resume to the corresponding runtime API call. 750 //===----------------------------------------------------------------------===// 751 752 namespace { 753 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 754 public: 755 using OpConversionPattern::OpConversionPattern; 756 757 LogicalResult 758 matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands, 759 ConversionPatternRewriter &rewriter) const override { 760 // A pointer to coroutine resume intrinsic wrapper. 761 addResumeFunction(op->getParentOfType<ModuleOp>()); 762 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 763 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 764 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 765 766 // Call async runtime API to execute a coroutine in the managed thread. 767 auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); 768 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute, 769 ValueRange({coroHdl, resumePtr.res()})); 770 771 return success(); 772 } 773 }; 774 } // namespace 775 776 //===----------------------------------------------------------------------===// 777 // Convert async.runtime.store to the corresponding runtime API call. 778 //===----------------------------------------------------------------------===// 779 780 namespace { 781 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 782 public: 783 using OpConversionPattern::OpConversionPattern; 784 785 LogicalResult 786 matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands, 787 ConversionPatternRewriter &rewriter) const override { 788 Location loc = op->getLoc(); 789 790 // Get a pointer to the async value storage from the runtime. 791 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 792 auto storage = RuntimeStoreOpAdaptor(operands).storage(); 793 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 794 TypeRange(i8Ptr), storage); 795 796 // Cast from i8* to the LLVM pointer type. 797 auto valueType = op.value().getType(); 798 auto llvmValueType = getTypeConverter()->convertType(valueType); 799 if (!llvmValueType) 800 return rewriter.notifyMatchFailure( 801 op, "failed to convert stored value type to LLVM type"); 802 803 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 804 loc, LLVM::LLVMPointerType::get(llvmValueType), 805 storagePtr.getResult(0)); 806 807 // Store the yielded value into the async value storage. 808 auto value = RuntimeStoreOpAdaptor(operands).value(); 809 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 810 811 // Erase the original runtime store operation. 812 rewriter.eraseOp(op); 813 814 return success(); 815 } 816 }; 817 } // namespace 818 819 //===----------------------------------------------------------------------===// 820 // Convert async.runtime.load to the corresponding runtime API call. 821 //===----------------------------------------------------------------------===// 822 823 namespace { 824 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 825 public: 826 using OpConversionPattern::OpConversionPattern; 827 828 LogicalResult 829 matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands, 830 ConversionPatternRewriter &rewriter) const override { 831 Location loc = op->getLoc(); 832 833 // Get a pointer to the async value storage from the runtime. 834 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 835 auto storage = RuntimeLoadOpAdaptor(operands).storage(); 836 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 837 TypeRange(i8Ptr), storage); 838 839 // Cast from i8* to the LLVM pointer type. 840 auto valueType = op.result().getType(); 841 auto llvmValueType = getTypeConverter()->convertType(valueType); 842 if (!llvmValueType) 843 return rewriter.notifyMatchFailure( 844 op, "failed to convert loaded value type to LLVM type"); 845 846 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 847 loc, LLVM::LLVMPointerType::get(llvmValueType), 848 storagePtr.getResult(0)); 849 850 // Load from the casted pointer. 851 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 852 853 return success(); 854 } 855 }; 856 } // namespace 857 858 //===----------------------------------------------------------------------===// 859 // Convert async.runtime.add_to_group to the corresponding runtime API call. 860 //===----------------------------------------------------------------------===// 861 862 namespace { 863 class RuntimeAddToGroupOpLowering 864 : public OpConversionPattern<RuntimeAddToGroupOp> { 865 public: 866 using OpConversionPattern::OpConversionPattern; 867 868 LogicalResult 869 matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands, 870 ConversionPatternRewriter &rewriter) const override { 871 // Currently we can only add tokens to the group. 872 if (!op.operand().getType().isa<TokenType>()) 873 return rewriter.notifyMatchFailure(op, "only token type is supported"); 874 875 // Replace with a runtime API function call. 876 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, 877 rewriter.getI64Type(), operands); 878 879 return success(); 880 } 881 }; 882 } // namespace 883 884 //===----------------------------------------------------------------------===// 885 // Async reference counting ops lowering (`async.runtime.add_ref` and 886 // `async.runtime.drop_ref` to the corresponding API calls). 887 //===----------------------------------------------------------------------===// 888 889 namespace { 890 template <typename RefCountingOp> 891 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 892 public: 893 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 894 StringRef apiFunctionName) 895 : OpConversionPattern<RefCountingOp>(converter, ctx), 896 apiFunctionName(apiFunctionName) {} 897 898 LogicalResult 899 matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands, 900 ConversionPatternRewriter &rewriter) const override { 901 auto count = 902 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(), 903 rewriter.getI32IntegerAttr(op.count())); 904 905 auto operand = typename RefCountingOp::Adaptor(operands).operand(); 906 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 907 ValueRange({operand, count})); 908 909 return success(); 910 } 911 912 private: 913 StringRef apiFunctionName; 914 }; 915 916 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 917 public: 918 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 919 : RefCountingOpLowering(converter, ctx, kAddRef) {} 920 }; 921 922 class RuntimeDropRefOpLowering 923 : public RefCountingOpLowering<RuntimeDropRefOp> { 924 public: 925 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 926 : RefCountingOpLowering(converter, ctx, kDropRef) {} 927 }; 928 } // namespace 929 930 //===----------------------------------------------------------------------===// 931 // Convert return operations that return async values from async regions. 932 //===----------------------------------------------------------------------===// 933 934 namespace { 935 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 936 public: 937 using OpConversionPattern::OpConversionPattern; 938 939 LogicalResult 940 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 941 ConversionPatternRewriter &rewriter) const override { 942 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 943 return success(); 944 } 945 }; 946 } // namespace 947 948 //===----------------------------------------------------------------------===// 949 950 namespace { 951 struct ConvertAsyncToLLVMPass 952 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 953 void runOnOperation() override; 954 }; 955 } // namespace 956 957 void ConvertAsyncToLLVMPass::runOnOperation() { 958 ModuleOp module = getOperation(); 959 MLIRContext *ctx = module->getContext(); 960 961 // Add declarations for most functions required by the coroutines lowering. 962 // We delay adding the resume function until it's needed because it currently 963 // fails to compile unless '-O0' is specified. 964 addAsyncRuntimeApiDeclarations(module); 965 addCRuntimeDeclarations(module); 966 967 // Lower async.runtime and async.coro operations to Async Runtime API and 968 // LLVM coroutine intrinsics. 969 970 // Convert async dialect types and operations to LLVM dialect. 971 AsyncRuntimeTypeConverter converter; 972 RewritePatternSet patterns(ctx); 973 974 // We use conversion to LLVM type to lower async.runtime load and store 975 // operations. 976 LLVMTypeConverter llvmConverter(ctx); 977 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 978 979 // Convert async types in function signatures and function calls. 980 populateFuncOpTypeConversionPattern(patterns, converter); 981 populateCallOpTypeConversionPattern(patterns, converter); 982 983 // Convert return operations inside async.execute regions. 984 patterns.add<ReturnOpOpConversion>(converter, ctx); 985 986 // Lower async.runtime operations to the async runtime API calls. 987 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 988 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 989 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 990 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 991 RuntimeDropRefOpLowering>(converter, ctx); 992 993 // Lower async.runtime operations that rely on LLVM type converter to convert 994 // from async value payload type to the LLVM type. 995 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 996 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 997 ctx); 998 999 // Lower async coroutine operations to LLVM coroutine intrinsics. 1000 patterns 1001 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 1002 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1003 converter, ctx); 1004 1005 ConversionTarget target(*ctx); 1006 target.addLegalOp<ConstantOp, UnrealizedConversionCastOp>(); 1007 target.addLegalDialect<LLVM::LLVMDialect>(); 1008 1009 // All operations from Async dialect must be lowered to the runtime API and 1010 // LLVM intrinsics calls. 1011 target.addIllegalDialect<AsyncDialect>(); 1012 1013 // Add dynamic legality constraints to apply conversions defined above. 1014 target.addDynamicallyLegalOp<FuncOp>( 1015 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1016 target.addDynamicallyLegalOp<ReturnOp>( 1017 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1018 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1019 return converter.isSignatureLegal(op.getCalleeType()); 1020 }); 1021 1022 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1023 signalPassFailure(); 1024 } 1025 1026 //===----------------------------------------------------------------------===// 1027 // Patterns for structural type conversions for the Async dialect operations. 1028 //===----------------------------------------------------------------------===// 1029 1030 namespace { 1031 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1032 public: 1033 using OpConversionPattern::OpConversionPattern; 1034 LogicalResult 1035 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands, 1036 ConversionPatternRewriter &rewriter) const override { 1037 ExecuteOp newOp = 1038 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1039 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1040 newOp.getRegion().end()); 1041 1042 // Set operands and update block argument and result types. 1043 newOp->setOperands(operands); 1044 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1045 return failure(); 1046 for (auto result : newOp.getResults()) 1047 result.setType(typeConverter->convertType(result.getType())); 1048 1049 rewriter.replaceOp(op, newOp.getResults()); 1050 return success(); 1051 } 1052 }; 1053 1054 // Dummy pattern to trigger the appropriate type conversion / materialization. 1055 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1056 public: 1057 using OpConversionPattern::OpConversionPattern; 1058 LogicalResult 1059 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands, 1060 ConversionPatternRewriter &rewriter) const override { 1061 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front()); 1062 return success(); 1063 } 1064 }; 1065 1066 // Dummy pattern to trigger the appropriate type conversion / materialization. 1067 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1068 public: 1069 using OpConversionPattern::OpConversionPattern; 1070 LogicalResult 1071 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 1072 ConversionPatternRewriter &rewriter) const override { 1073 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands); 1074 return success(); 1075 } 1076 }; 1077 } // namespace 1078 1079 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1080 return std::make_unique<ConvertAsyncToLLVMPass>(); 1081 } 1082 1083 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1084 TypeConverter &typeConverter, RewritePatternSet &patterns, 1085 ConversionTarget &target) { 1086 typeConverter.addConversion([&](TokenType type) { return type; }); 1087 typeConverter.addConversion([&](ValueType type) { 1088 Type converted = typeConverter.convertType(type.getValueType()); 1089 return converted ? ValueType::get(converted) : converted; 1090 }); 1091 1092 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1093 typeConverter, patterns.getContext()); 1094 1095 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1096 [&](Operation *op) { return typeConverter.isLegal(op); }); 1097 } 1098