1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 27 #define DEBUG_TYPE "convert-async-to-llvm" 28 29 using namespace mlir; 30 using namespace mlir::async; 31 32 //===----------------------------------------------------------------------===// 33 // Async Runtime C API declaration. 34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 37 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 38 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 39 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 40 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 41 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 42 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 43 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 44 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 45 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 46 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 47 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 48 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 49 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 50 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 51 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 52 static constexpr const char *kGetValueStorage = 53 "mlirAsyncRuntimeGetValueStorage"; 54 static constexpr const char *kAddTokenToGroup = 55 "mlirAsyncRuntimeAddTokenToGroup"; 56 static constexpr const char *kAwaitTokenAndExecute = 57 "mlirAsyncRuntimeAwaitTokenAndExecute"; 58 static constexpr const char *kAwaitValueAndExecute = 59 "mlirAsyncRuntimeAwaitValueAndExecute"; 60 static constexpr const char *kAwaitAllAndExecute = 61 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 62 63 namespace { 64 /// Async Runtime API function types. 65 /// 66 /// Because we can't create API function signature for type parametrized 67 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 68 /// lowering all async data types become opaque pointers at runtime. 69 struct AsyncAPI { 70 // All async types are lowered to opaque i8* LLVM pointers at runtime. 71 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 72 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 73 } 74 75 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 76 return LLVM::LLVMTokenType::get(ctx); 77 } 78 79 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 80 auto ref = opaquePointerType(ctx); 81 auto count = IntegerType::get(ctx, 64); 82 return FunctionType::get(ctx, {ref, count}, {}); 83 } 84 85 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 86 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 87 } 88 89 static FunctionType createValueFunctionType(MLIRContext *ctx) { 90 auto i64 = IntegerType::get(ctx, 64); 91 auto value = opaquePointerType(ctx); 92 return FunctionType::get(ctx, {i64}, {value}); 93 } 94 95 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 96 auto i64 = IntegerType::get(ctx, 64); 97 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 98 } 99 100 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 101 auto value = opaquePointerType(ctx); 102 auto storage = opaquePointerType(ctx); 103 return FunctionType::get(ctx, {value}, {storage}); 104 } 105 106 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 107 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 108 } 109 110 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 111 auto value = opaquePointerType(ctx); 112 return FunctionType::get(ctx, {value}, {}); 113 } 114 115 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 116 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 117 } 118 119 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 120 auto value = opaquePointerType(ctx); 121 return FunctionType::get(ctx, {value}, {}); 122 } 123 124 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 125 auto i1 = IntegerType::get(ctx, 1); 126 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 127 } 128 129 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 130 auto value = opaquePointerType(ctx); 131 auto i1 = IntegerType::get(ctx, 1); 132 return FunctionType::get(ctx, {value}, {i1}); 133 } 134 135 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 136 auto i1 = IntegerType::get(ctx, 1); 137 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 138 } 139 140 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 141 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 142 } 143 144 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 145 auto value = opaquePointerType(ctx); 146 return FunctionType::get(ctx, {value}, {}); 147 } 148 149 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 150 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 151 } 152 153 static FunctionType executeFunctionType(MLIRContext *ctx) { 154 auto hdl = opaquePointerType(ctx); 155 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 156 return FunctionType::get(ctx, {hdl, resume}, {}); 157 } 158 159 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 160 auto i64 = IntegerType::get(ctx, 64); 161 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 162 {i64}); 163 } 164 165 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 166 auto hdl = opaquePointerType(ctx); 167 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 168 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 169 } 170 171 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 172 auto value = opaquePointerType(ctx); 173 auto hdl = opaquePointerType(ctx); 174 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 175 return FunctionType::get(ctx, {value, hdl, resume}, {}); 176 } 177 178 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 179 auto hdl = opaquePointerType(ctx); 180 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 181 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 182 } 183 184 // Auxiliary coroutine resume intrinsic wrapper. 185 static Type resumeFunctionType(MLIRContext *ctx) { 186 auto voidTy = LLVM::LLVMVoidType::get(ctx); 187 auto i8Ptr = opaquePointerType(ctx); 188 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 189 } 190 }; 191 } // namespace 192 193 /// Adds Async Runtime C API declarations to the module. 194 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 195 auto builder = 196 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 197 198 auto addFuncDecl = [&](StringRef name, FunctionType type) { 199 if (module.lookupSymbol(name)) 200 return; 201 builder.create<FuncOp>(name, type).setPrivate(); 202 }; 203 204 MLIRContext *ctx = module.getContext(); 205 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 206 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 207 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 208 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 209 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 210 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 211 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 212 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 213 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 214 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 215 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 216 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 217 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 218 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 219 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 220 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 221 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 222 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 223 addFuncDecl(kAwaitTokenAndExecute, 224 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 225 addFuncDecl(kAwaitValueAndExecute, 226 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 227 addFuncDecl(kAwaitAllAndExecute, 228 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 229 } 230 231 //===----------------------------------------------------------------------===// 232 // Coroutine resume function wrapper. 233 //===----------------------------------------------------------------------===// 234 235 static constexpr const char *kResume = "__resume"; 236 237 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 238 /// intrinsics. We need this function to be able to pass it to the async 239 /// runtime execute API. 240 static void addResumeFunction(ModuleOp module) { 241 if (module.lookupSymbol(kResume)) 242 return; 243 244 MLIRContext *ctx = module.getContext(); 245 auto loc = module.getLoc(); 246 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 247 248 auto voidTy = LLVM::LLVMVoidType::get(ctx); 249 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 250 251 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 252 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 253 resumeOp.setPrivate(); 254 255 auto *block = resumeOp.addEntryBlock(); 256 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 257 258 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 259 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // Convert Async dialect types to LLVM types. 264 //===----------------------------------------------------------------------===// 265 266 namespace { 267 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 268 /// their runtime type (opaque pointers) and does not convert any other types. 269 class AsyncRuntimeTypeConverter : public TypeConverter { 270 public: 271 AsyncRuntimeTypeConverter() { 272 addConversion([](Type type) { return type; }); 273 addConversion(convertAsyncTypes); 274 } 275 276 static Optional<Type> convertAsyncTypes(Type type) { 277 if (type.isa<TokenType, GroupType, ValueType>()) 278 return AsyncAPI::opaquePointerType(type.getContext()); 279 280 if (type.isa<CoroIdType, CoroStateType>()) 281 return AsyncAPI::tokenType(type.getContext()); 282 if (type.isa<CoroHandleType>()) 283 return AsyncAPI::opaquePointerType(type.getContext()); 284 285 return llvm::None; 286 } 287 }; 288 } // namespace 289 290 //===----------------------------------------------------------------------===// 291 // Convert async.coro.id to @llvm.coro.id intrinsic. 292 //===----------------------------------------------------------------------===// 293 294 namespace { 295 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 296 public: 297 using OpConversionPattern::OpConversionPattern; 298 299 LogicalResult 300 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 301 ConversionPatternRewriter &rewriter) const override { 302 auto token = AsyncAPI::tokenType(op->getContext()); 303 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 304 auto loc = op->getLoc(); 305 306 // Constants for initializing coroutine frame. 307 auto constZero = rewriter.create<LLVM::ConstantOp>( 308 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 309 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 310 311 // Get coroutine id: @llvm.coro.id. 312 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 313 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 314 315 return success(); 316 } 317 }; 318 } // namespace 319 320 //===----------------------------------------------------------------------===// 321 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 322 //===----------------------------------------------------------------------===// 323 324 namespace { 325 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 326 public: 327 using OpConversionPattern::OpConversionPattern; 328 329 LogicalResult 330 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 331 ConversionPatternRewriter &rewriter) const override { 332 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 333 auto loc = op->getLoc(); 334 335 // Get coroutine frame size: @llvm.coro.size.i64. 336 Value coroSize = 337 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 338 // The coroutine lowering doesn't properly account for alignment of the 339 // frame, so align everything to 64 bytes which ought to be enough for 340 // everyone. https://llvm.org/PR53148 341 constexpr int64_t coroAlign = 64; 342 auto makeConstant = [&](uint64_t c) { 343 return rewriter.create<LLVM::ConstantOp>( 344 op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c)); 345 }; 346 // Round up the size to the alignment. This is a requirement of 347 // aligned_alloc. 348 coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, 349 makeConstant(coroAlign - 1)); 350 coroSize = rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, 351 makeConstant(-coroAlign)); 352 353 // Allocate memory for the coroutine frame. 354 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( 355 op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); 356 auto coroAlloc = rewriter.create<LLVM::CallOp>( 357 loc, i8Ptr, SymbolRefAttr::get(allocFuncOp), 358 ValueRange{makeConstant(coroAlign), coroSize}); 359 360 // Begin a coroutine: @llvm.coro.begin. 361 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); 362 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 363 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 364 365 return success(); 366 } 367 }; 368 } // namespace 369 370 //===----------------------------------------------------------------------===// 371 // Convert async.coro.free to @llvm.coro.free intrinsic. 372 //===----------------------------------------------------------------------===// 373 374 namespace { 375 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 376 public: 377 using OpConversionPattern::OpConversionPattern; 378 379 LogicalResult 380 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 381 ConversionPatternRewriter &rewriter) const override { 382 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 383 auto loc = op->getLoc(); 384 385 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 386 auto coroMem = 387 rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands()); 388 389 // Free the memory. 390 auto freeFuncOp = 391 LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 392 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 393 SymbolRefAttr::get(freeFuncOp), 394 ValueRange(coroMem.getResult())); 395 396 return success(); 397 } 398 }; 399 } // namespace 400 401 //===----------------------------------------------------------------------===// 402 // Convert async.coro.end to @llvm.coro.end intrinsic. 403 //===----------------------------------------------------------------------===// 404 405 namespace { 406 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 407 public: 408 using OpConversionPattern::OpConversionPattern; 409 410 LogicalResult 411 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 412 ConversionPatternRewriter &rewriter) const override { 413 // We are not in the block that is part of the unwind sequence. 414 auto constFalse = rewriter.create<LLVM::ConstantOp>( 415 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 416 417 // Mark the end of a coroutine: @llvm.coro.end. 418 auto coroHdl = adaptor.handle(); 419 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 420 ValueRange({coroHdl, constFalse})); 421 rewriter.eraseOp(op); 422 423 return success(); 424 } 425 }; 426 } // namespace 427 428 //===----------------------------------------------------------------------===// 429 // Convert async.coro.save to @llvm.coro.save intrinsic. 430 //===----------------------------------------------------------------------===// 431 432 namespace { 433 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 434 public: 435 using OpConversionPattern::OpConversionPattern; 436 437 LogicalResult 438 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 439 ConversionPatternRewriter &rewriter) const override { 440 // Save the coroutine state: @llvm.coro.save 441 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 442 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 443 444 return success(); 445 } 446 }; 447 } // namespace 448 449 //===----------------------------------------------------------------------===// 450 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 451 //===----------------------------------------------------------------------===// 452 453 namespace { 454 455 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 456 /// branch to the appropriate block based on the return code. 457 /// 458 /// Before: 459 /// 460 /// ^suspended: 461 /// "opBefore"(...) 462 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 463 /// ^resume: 464 /// "op"(...) 465 /// ^cleanup: ... 466 /// ^suspend: ... 467 /// 468 /// After: 469 /// 470 /// ^suspended: 471 /// "opBefore"(...) 472 /// %suspend = llmv.intr.coro.suspend ... 473 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 474 /// ^resume: 475 /// "op"(...) 476 /// ^cleanup: ... 477 /// ^suspend: ... 478 /// 479 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 480 public: 481 using OpConversionPattern::OpConversionPattern; 482 483 LogicalResult 484 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 485 ConversionPatternRewriter &rewriter) const override { 486 auto i8 = rewriter.getIntegerType(8); 487 auto i32 = rewriter.getI32Type(); 488 auto loc = op->getLoc(); 489 490 // This is not a final suspension point. 491 auto constFalse = rewriter.create<LLVM::ConstantOp>( 492 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 493 494 // Suspend a coroutine: @llvm.coro.suspend 495 auto coroState = adaptor.state(); 496 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 497 loc, i8, ValueRange({coroState, constFalse})); 498 499 // Cast return code to i32. 500 501 // After a suspension point decide if we should branch into resume, cleanup 502 // or suspend block of the coroutine (see @llvm.coro.suspend return code 503 // documentation). 504 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 505 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 506 op.cleanupDest()}; 507 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 508 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 509 /*defaultDestination=*/op.suspendDest(), 510 /*defaultOperands=*/ValueRange(), 511 /*caseValues=*/caseValues, 512 /*caseDestinations=*/caseDest, 513 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 514 /*branchWeights=*/ArrayRef<int32_t>()); 515 516 return success(); 517 } 518 }; 519 } // namespace 520 521 //===----------------------------------------------------------------------===// 522 // Convert async.runtime.create to the corresponding runtime API call. 523 // 524 // To allocate storage for the async values we use getelementptr trick: 525 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 526 //===----------------------------------------------------------------------===// 527 528 namespace { 529 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 530 public: 531 using OpConversionPattern::OpConversionPattern; 532 533 LogicalResult 534 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 535 ConversionPatternRewriter &rewriter) const override { 536 TypeConverter *converter = getTypeConverter(); 537 Type resultType = op->getResultTypes()[0]; 538 539 // Tokens creation maps to a simple function call. 540 if (resultType.isa<TokenType>()) { 541 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken, 542 converter->convertType(resultType)); 543 return success(); 544 } 545 546 // To create a value we need to compute the storage requirement. 547 if (auto value = resultType.dyn_cast<ValueType>()) { 548 // Returns the size requirements for the async value storage. 549 auto sizeOf = [&](ValueType valueType) -> Value { 550 auto loc = op->getLoc(); 551 auto i64 = rewriter.getI64Type(); 552 553 auto storedType = converter->convertType(valueType.getValueType()); 554 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 555 556 // %Size = getelementptr %T* null, int 1 557 // %SizeI = ptrtoint %T* %Size to i64 558 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 559 auto one = rewriter.create<LLVM::ConstantOp>( 560 loc, i64, rewriter.getI64IntegerAttr(1)); 561 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 562 one.getResult()); 563 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 564 }; 565 566 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 567 sizeOf(value)); 568 569 return success(); 570 } 571 572 return rewriter.notifyMatchFailure(op, "unsupported async type"); 573 } 574 }; 575 } // namespace 576 577 //===----------------------------------------------------------------------===// 578 // Convert async.runtime.create_group to the corresponding runtime API call. 579 //===----------------------------------------------------------------------===// 580 581 namespace { 582 class RuntimeCreateGroupOpLowering 583 : public OpConversionPattern<RuntimeCreateGroupOp> { 584 public: 585 using OpConversionPattern::OpConversionPattern; 586 587 LogicalResult 588 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 589 ConversionPatternRewriter &rewriter) const override { 590 TypeConverter *converter = getTypeConverter(); 591 Type resultType = op.getResult().getType(); 592 593 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, 594 converter->convertType(resultType), 595 adaptor.getOperands()); 596 return success(); 597 } 598 }; 599 } // namespace 600 601 //===----------------------------------------------------------------------===// 602 // Convert async.runtime.set_available to the corresponding runtime API call. 603 //===----------------------------------------------------------------------===// 604 605 namespace { 606 class RuntimeSetAvailableOpLowering 607 : public OpConversionPattern<RuntimeSetAvailableOp> { 608 public: 609 using OpConversionPattern::OpConversionPattern; 610 611 LogicalResult 612 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 613 ConversionPatternRewriter &rewriter) const override { 614 StringRef apiFuncName = 615 TypeSwitch<Type, StringRef>(op.operand().getType()) 616 .Case<TokenType>([](Type) { return kEmplaceToken; }) 617 .Case<ValueType>([](Type) { return kEmplaceValue; }); 618 619 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 620 adaptor.getOperands()); 621 622 return success(); 623 } 624 }; 625 } // namespace 626 627 //===----------------------------------------------------------------------===// 628 // Convert async.runtime.set_error to the corresponding runtime API call. 629 //===----------------------------------------------------------------------===// 630 631 namespace { 632 class RuntimeSetErrorOpLowering 633 : public OpConversionPattern<RuntimeSetErrorOp> { 634 public: 635 using OpConversionPattern::OpConversionPattern; 636 637 LogicalResult 638 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 639 ConversionPatternRewriter &rewriter) const override { 640 StringRef apiFuncName = 641 TypeSwitch<Type, StringRef>(op.operand().getType()) 642 .Case<TokenType>([](Type) { return kSetTokenError; }) 643 .Case<ValueType>([](Type) { return kSetValueError; }); 644 645 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 646 adaptor.getOperands()); 647 648 return success(); 649 } 650 }; 651 } // namespace 652 653 //===----------------------------------------------------------------------===// 654 // Convert async.runtime.is_error to the corresponding runtime API call. 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 659 public: 660 using OpConversionPattern::OpConversionPattern; 661 662 LogicalResult 663 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 664 ConversionPatternRewriter &rewriter) const override { 665 StringRef apiFuncName = 666 TypeSwitch<Type, StringRef>(op.operand().getType()) 667 .Case<TokenType>([](Type) { return kIsTokenError; }) 668 .Case<GroupType>([](Type) { return kIsGroupError; }) 669 .Case<ValueType>([](Type) { return kIsValueError; }); 670 671 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 672 adaptor.getOperands()); 673 return success(); 674 } 675 }; 676 } // namespace 677 678 //===----------------------------------------------------------------------===// 679 // Convert async.runtime.await to the corresponding runtime API call. 680 //===----------------------------------------------------------------------===// 681 682 namespace { 683 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 684 public: 685 using OpConversionPattern::OpConversionPattern; 686 687 LogicalResult 688 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 689 ConversionPatternRewriter &rewriter) const override { 690 StringRef apiFuncName = 691 TypeSwitch<Type, StringRef>(op.operand().getType()) 692 .Case<TokenType>([](Type) { return kAwaitToken; }) 693 .Case<ValueType>([](Type) { return kAwaitValue; }) 694 .Case<GroupType>([](Type) { return kAwaitGroup; }); 695 696 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 697 adaptor.getOperands()); 698 rewriter.eraseOp(op); 699 700 return success(); 701 } 702 }; 703 } // namespace 704 705 //===----------------------------------------------------------------------===// 706 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 707 //===----------------------------------------------------------------------===// 708 709 namespace { 710 class RuntimeAwaitAndResumeOpLowering 711 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 712 public: 713 using OpConversionPattern::OpConversionPattern; 714 715 LogicalResult 716 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 717 ConversionPatternRewriter &rewriter) const override { 718 StringRef apiFuncName = 719 TypeSwitch<Type, StringRef>(op.operand().getType()) 720 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 721 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 722 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 723 724 Value operand = adaptor.operand(); 725 Value handle = adaptor.handle(); 726 727 // A pointer to coroutine resume intrinsic wrapper. 728 addResumeFunction(op->getParentOfType<ModuleOp>()); 729 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 730 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 731 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 732 733 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 734 ValueRange({operand, handle, resumePtr.getRes()})); 735 rewriter.eraseOp(op); 736 737 return success(); 738 } 739 }; 740 } // namespace 741 742 //===----------------------------------------------------------------------===// 743 // Convert async.runtime.resume to the corresponding runtime API call. 744 //===----------------------------------------------------------------------===// 745 746 namespace { 747 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 748 public: 749 using OpConversionPattern::OpConversionPattern; 750 751 LogicalResult 752 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 753 ConversionPatternRewriter &rewriter) const override { 754 // A pointer to coroutine resume intrinsic wrapper. 755 addResumeFunction(op->getParentOfType<ModuleOp>()); 756 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 757 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 758 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 759 760 // Call async runtime API to execute a coroutine in the managed thread. 761 auto coroHdl = adaptor.handle(); 762 rewriter.replaceOpWithNewOp<CallOp>( 763 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); 764 765 return success(); 766 } 767 }; 768 } // namespace 769 770 //===----------------------------------------------------------------------===// 771 // Convert async.runtime.store to the corresponding runtime API call. 772 //===----------------------------------------------------------------------===// 773 774 namespace { 775 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 776 public: 777 using OpConversionPattern::OpConversionPattern; 778 779 LogicalResult 780 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 781 ConversionPatternRewriter &rewriter) const override { 782 Location loc = op->getLoc(); 783 784 // Get a pointer to the async value storage from the runtime. 785 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 786 auto storage = adaptor.storage(); 787 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 788 TypeRange(i8Ptr), storage); 789 790 // Cast from i8* to the LLVM pointer type. 791 auto valueType = op.value().getType(); 792 auto llvmValueType = getTypeConverter()->convertType(valueType); 793 if (!llvmValueType) 794 return rewriter.notifyMatchFailure( 795 op, "failed to convert stored value type to LLVM type"); 796 797 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 798 loc, LLVM::LLVMPointerType::get(llvmValueType), 799 storagePtr.getResult(0)); 800 801 // Store the yielded value into the async value storage. 802 auto value = adaptor.value(); 803 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 804 805 // Erase the original runtime store operation. 806 rewriter.eraseOp(op); 807 808 return success(); 809 } 810 }; 811 } // namespace 812 813 //===----------------------------------------------------------------------===// 814 // Convert async.runtime.load to the corresponding runtime API call. 815 //===----------------------------------------------------------------------===// 816 817 namespace { 818 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 819 public: 820 using OpConversionPattern::OpConversionPattern; 821 822 LogicalResult 823 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 824 ConversionPatternRewriter &rewriter) const override { 825 Location loc = op->getLoc(); 826 827 // Get a pointer to the async value storage from the runtime. 828 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 829 auto storage = adaptor.storage(); 830 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 831 TypeRange(i8Ptr), storage); 832 833 // Cast from i8* to the LLVM pointer type. 834 auto valueType = op.result().getType(); 835 auto llvmValueType = getTypeConverter()->convertType(valueType); 836 if (!llvmValueType) 837 return rewriter.notifyMatchFailure( 838 op, "failed to convert loaded value type to LLVM type"); 839 840 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 841 loc, LLVM::LLVMPointerType::get(llvmValueType), 842 storagePtr.getResult(0)); 843 844 // Load from the casted pointer. 845 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 846 847 return success(); 848 } 849 }; 850 } // namespace 851 852 //===----------------------------------------------------------------------===// 853 // Convert async.runtime.add_to_group to the corresponding runtime API call. 854 //===----------------------------------------------------------------------===// 855 856 namespace { 857 class RuntimeAddToGroupOpLowering 858 : public OpConversionPattern<RuntimeAddToGroupOp> { 859 public: 860 using OpConversionPattern::OpConversionPattern; 861 862 LogicalResult 863 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 864 ConversionPatternRewriter &rewriter) const override { 865 // Currently we can only add tokens to the group. 866 if (!op.operand().getType().isa<TokenType>()) 867 return rewriter.notifyMatchFailure(op, "only token type is supported"); 868 869 // Replace with a runtime API function call. 870 rewriter.replaceOpWithNewOp<CallOp>( 871 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 872 873 return success(); 874 } 875 }; 876 } // namespace 877 878 //===----------------------------------------------------------------------===// 879 // Async reference counting ops lowering (`async.runtime.add_ref` and 880 // `async.runtime.drop_ref` to the corresponding API calls). 881 //===----------------------------------------------------------------------===// 882 883 namespace { 884 template <typename RefCountingOp> 885 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 886 public: 887 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 888 StringRef apiFunctionName) 889 : OpConversionPattern<RefCountingOp>(converter, ctx), 890 apiFunctionName(apiFunctionName) {} 891 892 LogicalResult 893 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 894 ConversionPatternRewriter &rewriter) const override { 895 auto count = rewriter.create<arith::ConstantOp>( 896 op->getLoc(), rewriter.getI64Type(), 897 rewriter.getI64IntegerAttr(op.count())); 898 899 auto operand = adaptor.operand(); 900 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 901 ValueRange({operand, count})); 902 903 return success(); 904 } 905 906 private: 907 StringRef apiFunctionName; 908 }; 909 910 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 911 public: 912 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 913 : RefCountingOpLowering(converter, ctx, kAddRef) {} 914 }; 915 916 class RuntimeDropRefOpLowering 917 : public RefCountingOpLowering<RuntimeDropRefOp> { 918 public: 919 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 920 : RefCountingOpLowering(converter, ctx, kDropRef) {} 921 }; 922 } // namespace 923 924 //===----------------------------------------------------------------------===// 925 // Convert return operations that return async values from async regions. 926 //===----------------------------------------------------------------------===// 927 928 namespace { 929 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 930 public: 931 using OpConversionPattern::OpConversionPattern; 932 933 LogicalResult 934 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 935 ConversionPatternRewriter &rewriter) const override { 936 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 937 return success(); 938 } 939 }; 940 } // namespace 941 942 //===----------------------------------------------------------------------===// 943 944 namespace { 945 struct ConvertAsyncToLLVMPass 946 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 947 void runOnOperation() override; 948 }; 949 } // namespace 950 951 void ConvertAsyncToLLVMPass::runOnOperation() { 952 ModuleOp module = getOperation(); 953 MLIRContext *ctx = module->getContext(); 954 955 // Add declarations for most functions required by the coroutines lowering. 956 // We delay adding the resume function until it's needed because it currently 957 // fails to compile unless '-O0' is specified. 958 addAsyncRuntimeApiDeclarations(module); 959 960 // Lower async.runtime and async.coro operations to Async Runtime API and 961 // LLVM coroutine intrinsics. 962 963 // Convert async dialect types and operations to LLVM dialect. 964 AsyncRuntimeTypeConverter converter; 965 RewritePatternSet patterns(ctx); 966 967 // We use conversion to LLVM type to lower async.runtime load and store 968 // operations. 969 LLVMTypeConverter llvmConverter(ctx); 970 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 971 972 // Convert async types in function signatures and function calls. 973 populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, converter); 974 populateCallOpTypeConversionPattern(patterns, converter); 975 976 // Convert return operations inside async.execute regions. 977 patterns.add<ReturnOpOpConversion>(converter, ctx); 978 979 // Lower async.runtime operations to the async runtime API calls. 980 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 981 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 982 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 983 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 984 RuntimeDropRefOpLowering>(converter, ctx); 985 986 // Lower async.runtime operations that rely on LLVM type converter to convert 987 // from async value payload type to the LLVM type. 988 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 989 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 990 ctx); 991 992 // Lower async coroutine operations to LLVM coroutine intrinsics. 993 patterns 994 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 995 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 996 converter, ctx); 997 998 ConversionTarget target(*ctx); 999 target 1000 .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>(); 1001 target.addLegalDialect<LLVM::LLVMDialect>(); 1002 1003 // All operations from Async dialect must be lowered to the runtime API and 1004 // LLVM intrinsics calls. 1005 target.addIllegalDialect<AsyncDialect>(); 1006 1007 // Add dynamic legality constraints to apply conversions defined above. 1008 target.addDynamicallyLegalOp<FuncOp>( 1009 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1010 target.addDynamicallyLegalOp<ReturnOp>( 1011 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1012 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1013 return converter.isSignatureLegal(op.getCalleeType()); 1014 }); 1015 1016 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1017 signalPassFailure(); 1018 } 1019 1020 //===----------------------------------------------------------------------===// 1021 // Patterns for structural type conversions for the Async dialect operations. 1022 //===----------------------------------------------------------------------===// 1023 1024 namespace { 1025 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1026 public: 1027 using OpConversionPattern::OpConversionPattern; 1028 LogicalResult 1029 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1030 ConversionPatternRewriter &rewriter) const override { 1031 ExecuteOp newOp = 1032 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1033 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1034 newOp.getRegion().end()); 1035 1036 // Set operands and update block argument and result types. 1037 newOp->setOperands(adaptor.getOperands()); 1038 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1039 return failure(); 1040 for (auto result : newOp.getResults()) 1041 result.setType(typeConverter->convertType(result.getType())); 1042 1043 rewriter.replaceOp(op, newOp.getResults()); 1044 return success(); 1045 } 1046 }; 1047 1048 // Dummy pattern to trigger the appropriate type conversion / materialization. 1049 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1050 public: 1051 using OpConversionPattern::OpConversionPattern; 1052 LogicalResult 1053 matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1054 ConversionPatternRewriter &rewriter) const override { 1055 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1056 return success(); 1057 } 1058 }; 1059 1060 // Dummy pattern to trigger the appropriate type conversion / materialization. 1061 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1062 public: 1063 using OpConversionPattern::OpConversionPattern; 1064 LogicalResult 1065 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1066 ConversionPatternRewriter &rewriter) const override { 1067 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1068 return success(); 1069 } 1070 }; 1071 } // namespace 1072 1073 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1074 return std::make_unique<ConvertAsyncToLLVMPass>(); 1075 } 1076 1077 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1078 TypeConverter &typeConverter, RewritePatternSet &patterns, 1079 ConversionTarget &target) { 1080 typeConverter.addConversion([&](TokenType type) { return type; }); 1081 typeConverter.addConversion([&](ValueType type) { 1082 Type converted = typeConverter.convertType(type.getValueType()); 1083 return converted ? ValueType::get(converted) : converted; 1084 }); 1085 1086 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1087 typeConverter, patterns.getContext()); 1088 1089 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1090 [&](Operation *op) { return typeConverter.isLegal(op); }); 1091 } 1092