1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 27 #define DEBUG_TYPE "convert-async-to-llvm" 28 29 using namespace mlir; 30 using namespace mlir::async; 31 32 //===----------------------------------------------------------------------===// 33 // Async Runtime C API declaration. 34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 37 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 38 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 39 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 40 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 41 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 42 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 43 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 44 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 45 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 46 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 47 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 48 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 49 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 50 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 51 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 52 static constexpr const char *kGetValueStorage = 53 "mlirAsyncRuntimeGetValueStorage"; 54 static constexpr const char *kAddTokenToGroup = 55 "mlirAsyncRuntimeAddTokenToGroup"; 56 static constexpr const char *kAwaitTokenAndExecute = 57 "mlirAsyncRuntimeAwaitTokenAndExecute"; 58 static constexpr const char *kAwaitValueAndExecute = 59 "mlirAsyncRuntimeAwaitValueAndExecute"; 60 static constexpr const char *kAwaitAllAndExecute = 61 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 62 63 namespace { 64 /// Async Runtime API function types. 65 /// 66 /// Because we can't create API function signature for type parametrized 67 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 68 /// lowering all async data types become opaque pointers at runtime. 69 struct AsyncAPI { 70 // All async types are lowered to opaque i8* LLVM pointers at runtime. 71 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 72 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 73 } 74 75 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 76 return LLVM::LLVMTokenType::get(ctx); 77 } 78 79 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 80 auto ref = opaquePointerType(ctx); 81 auto count = IntegerType::get(ctx, 64); 82 return FunctionType::get(ctx, {ref, count}, {}); 83 } 84 85 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 86 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 87 } 88 89 static FunctionType createValueFunctionType(MLIRContext *ctx) { 90 auto i64 = IntegerType::get(ctx, 64); 91 auto value = opaquePointerType(ctx); 92 return FunctionType::get(ctx, {i64}, {value}); 93 } 94 95 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 96 auto i64 = IntegerType::get(ctx, 64); 97 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 98 } 99 100 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 101 auto value = opaquePointerType(ctx); 102 auto storage = opaquePointerType(ctx); 103 return FunctionType::get(ctx, {value}, {storage}); 104 } 105 106 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 107 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 108 } 109 110 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 111 auto value = opaquePointerType(ctx); 112 return FunctionType::get(ctx, {value}, {}); 113 } 114 115 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 116 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 117 } 118 119 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 120 auto value = opaquePointerType(ctx); 121 return FunctionType::get(ctx, {value}, {}); 122 } 123 124 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 125 auto i1 = IntegerType::get(ctx, 1); 126 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 127 } 128 129 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 130 auto value = opaquePointerType(ctx); 131 auto i1 = IntegerType::get(ctx, 1); 132 return FunctionType::get(ctx, {value}, {i1}); 133 } 134 135 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 136 auto i1 = IntegerType::get(ctx, 1); 137 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 138 } 139 140 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 141 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 142 } 143 144 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 145 auto value = opaquePointerType(ctx); 146 return FunctionType::get(ctx, {value}, {}); 147 } 148 149 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 150 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 151 } 152 153 static FunctionType executeFunctionType(MLIRContext *ctx) { 154 auto hdl = opaquePointerType(ctx); 155 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 156 return FunctionType::get(ctx, {hdl, resume}, {}); 157 } 158 159 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 160 auto i64 = IntegerType::get(ctx, 64); 161 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 162 {i64}); 163 } 164 165 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 166 auto hdl = opaquePointerType(ctx); 167 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 168 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 169 } 170 171 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 172 auto value = opaquePointerType(ctx); 173 auto hdl = opaquePointerType(ctx); 174 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 175 return FunctionType::get(ctx, {value, hdl, resume}, {}); 176 } 177 178 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 179 auto hdl = opaquePointerType(ctx); 180 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 181 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 182 } 183 184 // Auxiliary coroutine resume intrinsic wrapper. 185 static Type resumeFunctionType(MLIRContext *ctx) { 186 auto voidTy = LLVM::LLVMVoidType::get(ctx); 187 auto i8Ptr = opaquePointerType(ctx); 188 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 189 } 190 }; 191 } // namespace 192 193 /// Adds Async Runtime C API declarations to the module. 194 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 195 auto builder = 196 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 197 198 auto addFuncDecl = [&](StringRef name, FunctionType type) { 199 if (module.lookupSymbol(name)) 200 return; 201 builder.create<FuncOp>(name, type).setPrivate(); 202 }; 203 204 MLIRContext *ctx = module.getContext(); 205 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 206 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 207 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 208 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 209 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 210 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 211 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 212 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 213 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 214 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 215 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 216 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 217 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 218 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 219 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 220 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 221 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 222 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 223 addFuncDecl(kAwaitTokenAndExecute, 224 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 225 addFuncDecl(kAwaitValueAndExecute, 226 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 227 addFuncDecl(kAwaitAllAndExecute, 228 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 229 } 230 231 //===----------------------------------------------------------------------===// 232 // Coroutine resume function wrapper. 233 //===----------------------------------------------------------------------===// 234 235 static constexpr const char *kResume = "__resume"; 236 237 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 238 /// intrinsics. We need this function to be able to pass it to the async 239 /// runtime execute API. 240 static void addResumeFunction(ModuleOp module) { 241 if (module.lookupSymbol(kResume)) 242 return; 243 244 MLIRContext *ctx = module.getContext(); 245 auto loc = module.getLoc(); 246 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 247 248 auto voidTy = LLVM::LLVMVoidType::get(ctx); 249 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 250 251 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 252 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 253 resumeOp.setPrivate(); 254 255 auto *block = resumeOp.addEntryBlock(); 256 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 257 258 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 259 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // Convert Async dialect types to LLVM types. 264 //===----------------------------------------------------------------------===// 265 266 namespace { 267 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 268 /// their runtime type (opaque pointers) and does not convert any other types. 269 class AsyncRuntimeTypeConverter : public TypeConverter { 270 public: 271 AsyncRuntimeTypeConverter() { 272 addConversion([](Type type) { return type; }); 273 addConversion(convertAsyncTypes); 274 } 275 276 static Optional<Type> convertAsyncTypes(Type type) { 277 if (type.isa<TokenType, GroupType, ValueType>()) 278 return AsyncAPI::opaquePointerType(type.getContext()); 279 280 if (type.isa<CoroIdType, CoroStateType>()) 281 return AsyncAPI::tokenType(type.getContext()); 282 if (type.isa<CoroHandleType>()) 283 return AsyncAPI::opaquePointerType(type.getContext()); 284 285 return llvm::None; 286 } 287 }; 288 } // namespace 289 290 //===----------------------------------------------------------------------===// 291 // Convert async.coro.id to @llvm.coro.id intrinsic. 292 //===----------------------------------------------------------------------===// 293 294 namespace { 295 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 296 public: 297 using OpConversionPattern::OpConversionPattern; 298 299 LogicalResult 300 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 301 ConversionPatternRewriter &rewriter) const override { 302 auto token = AsyncAPI::tokenType(op->getContext()); 303 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 304 auto loc = op->getLoc(); 305 306 // Constants for initializing coroutine frame. 307 auto constZero = rewriter.create<LLVM::ConstantOp>( 308 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 309 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 310 311 // Get coroutine id: @llvm.coro.id. 312 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 313 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 314 315 return success(); 316 } 317 }; 318 } // namespace 319 320 //===----------------------------------------------------------------------===// 321 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 322 //===----------------------------------------------------------------------===// 323 324 namespace { 325 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 326 public: 327 using OpConversionPattern::OpConversionPattern; 328 329 LogicalResult 330 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 331 ConversionPatternRewriter &rewriter) const override { 332 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 333 auto loc = op->getLoc(); 334 335 // Get coroutine frame size: @llvm.coro.size.i64. 336 Value coroSize = 337 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 338 // Get coroutine frame alignment: @llvm.coro.align.i64. 339 Value coroAlign = 340 rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); 341 342 // Round up the size to be multiple of the alignment. Since aligned_alloc 343 // requires the size parameter be an integral multiple of the alignment 344 // parameter. 345 auto makeConstant = [&](uint64_t c) { 346 return rewriter.create<LLVM::ConstantOp>( 347 op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c)); 348 }; 349 coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); 350 coroSize = 351 rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); 352 Value NegCoroAlign = 353 rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); 354 coroSize = 355 rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, NegCoroAlign); 356 357 // Allocate memory for the coroutine frame. 358 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( 359 op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); 360 auto coroAlloc = rewriter.create<LLVM::CallOp>( 361 loc, i8Ptr, SymbolRefAttr::get(allocFuncOp), 362 ValueRange{coroAlign, coroSize}); 363 364 // Begin a coroutine: @llvm.coro.begin. 365 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); 366 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 367 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 368 369 return success(); 370 } 371 }; 372 } // namespace 373 374 //===----------------------------------------------------------------------===// 375 // Convert async.coro.free to @llvm.coro.free intrinsic. 376 //===----------------------------------------------------------------------===// 377 378 namespace { 379 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 380 public: 381 using OpConversionPattern::OpConversionPattern; 382 383 LogicalResult 384 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 385 ConversionPatternRewriter &rewriter) const override { 386 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 387 auto loc = op->getLoc(); 388 389 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 390 auto coroMem = 391 rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands()); 392 393 // Free the memory. 394 auto freeFuncOp = 395 LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 396 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 397 SymbolRefAttr::get(freeFuncOp), 398 ValueRange(coroMem.getResult())); 399 400 return success(); 401 } 402 }; 403 } // namespace 404 405 //===----------------------------------------------------------------------===// 406 // Convert async.coro.end to @llvm.coro.end intrinsic. 407 //===----------------------------------------------------------------------===// 408 409 namespace { 410 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 411 public: 412 using OpConversionPattern::OpConversionPattern; 413 414 LogicalResult 415 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 416 ConversionPatternRewriter &rewriter) const override { 417 // We are not in the block that is part of the unwind sequence. 418 auto constFalse = rewriter.create<LLVM::ConstantOp>( 419 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 420 421 // Mark the end of a coroutine: @llvm.coro.end. 422 auto coroHdl = adaptor.handle(); 423 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 424 ValueRange({coroHdl, constFalse})); 425 rewriter.eraseOp(op); 426 427 return success(); 428 } 429 }; 430 } // namespace 431 432 //===----------------------------------------------------------------------===// 433 // Convert async.coro.save to @llvm.coro.save intrinsic. 434 //===----------------------------------------------------------------------===// 435 436 namespace { 437 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 438 public: 439 using OpConversionPattern::OpConversionPattern; 440 441 LogicalResult 442 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 443 ConversionPatternRewriter &rewriter) const override { 444 // Save the coroutine state: @llvm.coro.save 445 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 446 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 447 448 return success(); 449 } 450 }; 451 } // namespace 452 453 //===----------------------------------------------------------------------===// 454 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 455 //===----------------------------------------------------------------------===// 456 457 namespace { 458 459 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 460 /// branch to the appropriate block based on the return code. 461 /// 462 /// Before: 463 /// 464 /// ^suspended: 465 /// "opBefore"(...) 466 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 467 /// ^resume: 468 /// "op"(...) 469 /// ^cleanup: ... 470 /// ^suspend: ... 471 /// 472 /// After: 473 /// 474 /// ^suspended: 475 /// "opBefore"(...) 476 /// %suspend = llmv.intr.coro.suspend ... 477 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 478 /// ^resume: 479 /// "op"(...) 480 /// ^cleanup: ... 481 /// ^suspend: ... 482 /// 483 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 484 public: 485 using OpConversionPattern::OpConversionPattern; 486 487 LogicalResult 488 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 489 ConversionPatternRewriter &rewriter) const override { 490 auto i8 = rewriter.getIntegerType(8); 491 auto i32 = rewriter.getI32Type(); 492 auto loc = op->getLoc(); 493 494 // This is not a final suspension point. 495 auto constFalse = rewriter.create<LLVM::ConstantOp>( 496 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 497 498 // Suspend a coroutine: @llvm.coro.suspend 499 auto coroState = adaptor.state(); 500 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 501 loc, i8, ValueRange({coroState, constFalse})); 502 503 // Cast return code to i32. 504 505 // After a suspension point decide if we should branch into resume, cleanup 506 // or suspend block of the coroutine (see @llvm.coro.suspend return code 507 // documentation). 508 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 509 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 510 op.cleanupDest()}; 511 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 512 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 513 /*defaultDestination=*/op.suspendDest(), 514 /*defaultOperands=*/ValueRange(), 515 /*caseValues=*/caseValues, 516 /*caseDestinations=*/caseDest, 517 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 518 /*branchWeights=*/ArrayRef<int32_t>()); 519 520 return success(); 521 } 522 }; 523 } // namespace 524 525 //===----------------------------------------------------------------------===// 526 // Convert async.runtime.create to the corresponding runtime API call. 527 // 528 // To allocate storage for the async values we use getelementptr trick: 529 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 530 //===----------------------------------------------------------------------===// 531 532 namespace { 533 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 534 public: 535 using OpConversionPattern::OpConversionPattern; 536 537 LogicalResult 538 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 539 ConversionPatternRewriter &rewriter) const override { 540 TypeConverter *converter = getTypeConverter(); 541 Type resultType = op->getResultTypes()[0]; 542 543 // Tokens creation maps to a simple function call. 544 if (resultType.isa<TokenType>()) { 545 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken, 546 converter->convertType(resultType)); 547 return success(); 548 } 549 550 // To create a value we need to compute the storage requirement. 551 if (auto value = resultType.dyn_cast<ValueType>()) { 552 // Returns the size requirements for the async value storage. 553 auto sizeOf = [&](ValueType valueType) -> Value { 554 auto loc = op->getLoc(); 555 auto i64 = rewriter.getI64Type(); 556 557 auto storedType = converter->convertType(valueType.getValueType()); 558 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 559 560 // %Size = getelementptr %T* null, int 1 561 // %SizeI = ptrtoint %T* %Size to i64 562 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 563 auto one = rewriter.create<LLVM::ConstantOp>( 564 loc, i64, rewriter.getI64IntegerAttr(1)); 565 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 566 one.getResult()); 567 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 568 }; 569 570 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 571 sizeOf(value)); 572 573 return success(); 574 } 575 576 return rewriter.notifyMatchFailure(op, "unsupported async type"); 577 } 578 }; 579 } // namespace 580 581 //===----------------------------------------------------------------------===// 582 // Convert async.runtime.create_group to the corresponding runtime API call. 583 //===----------------------------------------------------------------------===// 584 585 namespace { 586 class RuntimeCreateGroupOpLowering 587 : public OpConversionPattern<RuntimeCreateGroupOp> { 588 public: 589 using OpConversionPattern::OpConversionPattern; 590 591 LogicalResult 592 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 593 ConversionPatternRewriter &rewriter) const override { 594 TypeConverter *converter = getTypeConverter(); 595 Type resultType = op.getResult().getType(); 596 597 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, 598 converter->convertType(resultType), 599 adaptor.getOperands()); 600 return success(); 601 } 602 }; 603 } // namespace 604 605 //===----------------------------------------------------------------------===// 606 // Convert async.runtime.set_available to the corresponding runtime API call. 607 //===----------------------------------------------------------------------===// 608 609 namespace { 610 class RuntimeSetAvailableOpLowering 611 : public OpConversionPattern<RuntimeSetAvailableOp> { 612 public: 613 using OpConversionPattern::OpConversionPattern; 614 615 LogicalResult 616 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 617 ConversionPatternRewriter &rewriter) const override { 618 StringRef apiFuncName = 619 TypeSwitch<Type, StringRef>(op.operand().getType()) 620 .Case<TokenType>([](Type) { return kEmplaceToken; }) 621 .Case<ValueType>([](Type) { return kEmplaceValue; }); 622 623 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 624 adaptor.getOperands()); 625 626 return success(); 627 } 628 }; 629 } // namespace 630 631 //===----------------------------------------------------------------------===// 632 // Convert async.runtime.set_error to the corresponding runtime API call. 633 //===----------------------------------------------------------------------===// 634 635 namespace { 636 class RuntimeSetErrorOpLowering 637 : public OpConversionPattern<RuntimeSetErrorOp> { 638 public: 639 using OpConversionPattern::OpConversionPattern; 640 641 LogicalResult 642 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 643 ConversionPatternRewriter &rewriter) const override { 644 StringRef apiFuncName = 645 TypeSwitch<Type, StringRef>(op.operand().getType()) 646 .Case<TokenType>([](Type) { return kSetTokenError; }) 647 .Case<ValueType>([](Type) { return kSetValueError; }); 648 649 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 650 adaptor.getOperands()); 651 652 return success(); 653 } 654 }; 655 } // namespace 656 657 //===----------------------------------------------------------------------===// 658 // Convert async.runtime.is_error to the corresponding runtime API call. 659 //===----------------------------------------------------------------------===// 660 661 namespace { 662 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 663 public: 664 using OpConversionPattern::OpConversionPattern; 665 666 LogicalResult 667 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 668 ConversionPatternRewriter &rewriter) const override { 669 StringRef apiFuncName = 670 TypeSwitch<Type, StringRef>(op.operand().getType()) 671 .Case<TokenType>([](Type) { return kIsTokenError; }) 672 .Case<GroupType>([](Type) { return kIsGroupError; }) 673 .Case<ValueType>([](Type) { return kIsValueError; }); 674 675 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 676 adaptor.getOperands()); 677 return success(); 678 } 679 }; 680 } // namespace 681 682 //===----------------------------------------------------------------------===// 683 // Convert async.runtime.await to the corresponding runtime API call. 684 //===----------------------------------------------------------------------===// 685 686 namespace { 687 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 688 public: 689 using OpConversionPattern::OpConversionPattern; 690 691 LogicalResult 692 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 693 ConversionPatternRewriter &rewriter) const override { 694 StringRef apiFuncName = 695 TypeSwitch<Type, StringRef>(op.operand().getType()) 696 .Case<TokenType>([](Type) { return kAwaitToken; }) 697 .Case<ValueType>([](Type) { return kAwaitValue; }) 698 .Case<GroupType>([](Type) { return kAwaitGroup; }); 699 700 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 701 adaptor.getOperands()); 702 rewriter.eraseOp(op); 703 704 return success(); 705 } 706 }; 707 } // namespace 708 709 //===----------------------------------------------------------------------===// 710 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 711 //===----------------------------------------------------------------------===// 712 713 namespace { 714 class RuntimeAwaitAndResumeOpLowering 715 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 716 public: 717 using OpConversionPattern::OpConversionPattern; 718 719 LogicalResult 720 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 721 ConversionPatternRewriter &rewriter) const override { 722 StringRef apiFuncName = 723 TypeSwitch<Type, StringRef>(op.operand().getType()) 724 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 725 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 726 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 727 728 Value operand = adaptor.operand(); 729 Value handle = adaptor.handle(); 730 731 // A pointer to coroutine resume intrinsic wrapper. 732 addResumeFunction(op->getParentOfType<ModuleOp>()); 733 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 734 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 735 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 736 737 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 738 ValueRange({operand, handle, resumePtr.getRes()})); 739 rewriter.eraseOp(op); 740 741 return success(); 742 } 743 }; 744 } // namespace 745 746 //===----------------------------------------------------------------------===// 747 // Convert async.runtime.resume to the corresponding runtime API call. 748 //===----------------------------------------------------------------------===// 749 750 namespace { 751 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 752 public: 753 using OpConversionPattern::OpConversionPattern; 754 755 LogicalResult 756 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 757 ConversionPatternRewriter &rewriter) const override { 758 // A pointer to coroutine resume intrinsic wrapper. 759 addResumeFunction(op->getParentOfType<ModuleOp>()); 760 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 761 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 762 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 763 764 // Call async runtime API to execute a coroutine in the managed thread. 765 auto coroHdl = adaptor.handle(); 766 rewriter.replaceOpWithNewOp<CallOp>( 767 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); 768 769 return success(); 770 } 771 }; 772 } // namespace 773 774 //===----------------------------------------------------------------------===// 775 // Convert async.runtime.store to the corresponding runtime API call. 776 //===----------------------------------------------------------------------===// 777 778 namespace { 779 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 780 public: 781 using OpConversionPattern::OpConversionPattern; 782 783 LogicalResult 784 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 785 ConversionPatternRewriter &rewriter) const override { 786 Location loc = op->getLoc(); 787 788 // Get a pointer to the async value storage from the runtime. 789 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 790 auto storage = adaptor.storage(); 791 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 792 TypeRange(i8Ptr), storage); 793 794 // Cast from i8* to the LLVM pointer type. 795 auto valueType = op.value().getType(); 796 auto llvmValueType = getTypeConverter()->convertType(valueType); 797 if (!llvmValueType) 798 return rewriter.notifyMatchFailure( 799 op, "failed to convert stored value type to LLVM type"); 800 801 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 802 loc, LLVM::LLVMPointerType::get(llvmValueType), 803 storagePtr.getResult(0)); 804 805 // Store the yielded value into the async value storage. 806 auto value = adaptor.value(); 807 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 808 809 // Erase the original runtime store operation. 810 rewriter.eraseOp(op); 811 812 return success(); 813 } 814 }; 815 } // namespace 816 817 //===----------------------------------------------------------------------===// 818 // Convert async.runtime.load to the corresponding runtime API call. 819 //===----------------------------------------------------------------------===// 820 821 namespace { 822 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 823 public: 824 using OpConversionPattern::OpConversionPattern; 825 826 LogicalResult 827 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 828 ConversionPatternRewriter &rewriter) const override { 829 Location loc = op->getLoc(); 830 831 // Get a pointer to the async value storage from the runtime. 832 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 833 auto storage = adaptor.storage(); 834 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 835 TypeRange(i8Ptr), storage); 836 837 // Cast from i8* to the LLVM pointer type. 838 auto valueType = op.result().getType(); 839 auto llvmValueType = getTypeConverter()->convertType(valueType); 840 if (!llvmValueType) 841 return rewriter.notifyMatchFailure( 842 op, "failed to convert loaded value type to LLVM type"); 843 844 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 845 loc, LLVM::LLVMPointerType::get(llvmValueType), 846 storagePtr.getResult(0)); 847 848 // Load from the casted pointer. 849 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 850 851 return success(); 852 } 853 }; 854 } // namespace 855 856 //===----------------------------------------------------------------------===// 857 // Convert async.runtime.add_to_group to the corresponding runtime API call. 858 //===----------------------------------------------------------------------===// 859 860 namespace { 861 class RuntimeAddToGroupOpLowering 862 : public OpConversionPattern<RuntimeAddToGroupOp> { 863 public: 864 using OpConversionPattern::OpConversionPattern; 865 866 LogicalResult 867 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 868 ConversionPatternRewriter &rewriter) const override { 869 // Currently we can only add tokens to the group. 870 if (!op.operand().getType().isa<TokenType>()) 871 return rewriter.notifyMatchFailure(op, "only token type is supported"); 872 873 // Replace with a runtime API function call. 874 rewriter.replaceOpWithNewOp<CallOp>( 875 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 876 877 return success(); 878 } 879 }; 880 } // namespace 881 882 //===----------------------------------------------------------------------===// 883 // Async reference counting ops lowering (`async.runtime.add_ref` and 884 // `async.runtime.drop_ref` to the corresponding API calls). 885 //===----------------------------------------------------------------------===// 886 887 namespace { 888 template <typename RefCountingOp> 889 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 890 public: 891 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 892 StringRef apiFunctionName) 893 : OpConversionPattern<RefCountingOp>(converter, ctx), 894 apiFunctionName(apiFunctionName) {} 895 896 LogicalResult 897 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 898 ConversionPatternRewriter &rewriter) const override { 899 auto count = rewriter.create<arith::ConstantOp>( 900 op->getLoc(), rewriter.getI64Type(), 901 rewriter.getI64IntegerAttr(op.count())); 902 903 auto operand = adaptor.operand(); 904 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 905 ValueRange({operand, count})); 906 907 return success(); 908 } 909 910 private: 911 StringRef apiFunctionName; 912 }; 913 914 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 915 public: 916 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 917 : RefCountingOpLowering(converter, ctx, kAddRef) {} 918 }; 919 920 class RuntimeDropRefOpLowering 921 : public RefCountingOpLowering<RuntimeDropRefOp> { 922 public: 923 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 924 : RefCountingOpLowering(converter, ctx, kDropRef) {} 925 }; 926 } // namespace 927 928 //===----------------------------------------------------------------------===// 929 // Convert return operations that return async values from async regions. 930 //===----------------------------------------------------------------------===// 931 932 namespace { 933 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 934 public: 935 using OpConversionPattern::OpConversionPattern; 936 937 LogicalResult 938 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 939 ConversionPatternRewriter &rewriter) const override { 940 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 941 return success(); 942 } 943 }; 944 } // namespace 945 946 //===----------------------------------------------------------------------===// 947 948 namespace { 949 struct ConvertAsyncToLLVMPass 950 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 951 void runOnOperation() override; 952 }; 953 } // namespace 954 955 void ConvertAsyncToLLVMPass::runOnOperation() { 956 ModuleOp module = getOperation(); 957 MLIRContext *ctx = module->getContext(); 958 959 // Add declarations for most functions required by the coroutines lowering. 960 // We delay adding the resume function until it's needed because it currently 961 // fails to compile unless '-O0' is specified. 962 addAsyncRuntimeApiDeclarations(module); 963 964 // Lower async.runtime and async.coro operations to Async Runtime API and 965 // LLVM coroutine intrinsics. 966 967 // Convert async dialect types and operations to LLVM dialect. 968 AsyncRuntimeTypeConverter converter; 969 RewritePatternSet patterns(ctx); 970 971 // We use conversion to LLVM type to lower async.runtime load and store 972 // operations. 973 LLVMTypeConverter llvmConverter(ctx); 974 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 975 976 // Convert async types in function signatures and function calls. 977 populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, converter); 978 populateCallOpTypeConversionPattern(patterns, converter); 979 980 // Convert return operations inside async.execute regions. 981 patterns.add<ReturnOpOpConversion>(converter, ctx); 982 983 // Lower async.runtime operations to the async runtime API calls. 984 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 985 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 986 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 987 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 988 RuntimeDropRefOpLowering>(converter, ctx); 989 990 // Lower async.runtime operations that rely on LLVM type converter to convert 991 // from async value payload type to the LLVM type. 992 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 993 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 994 ctx); 995 996 // Lower async coroutine operations to LLVM coroutine intrinsics. 997 patterns 998 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 999 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1000 converter, ctx); 1001 1002 ConversionTarget target(*ctx); 1003 target 1004 .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>(); 1005 target.addLegalDialect<LLVM::LLVMDialect>(); 1006 1007 // All operations from Async dialect must be lowered to the runtime API and 1008 // LLVM intrinsics calls. 1009 target.addIllegalDialect<AsyncDialect>(); 1010 1011 // Add dynamic legality constraints to apply conversions defined above. 1012 target.addDynamicallyLegalOp<FuncOp>( 1013 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1014 target.addDynamicallyLegalOp<ReturnOp>( 1015 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1016 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1017 return converter.isSignatureLegal(op.getCalleeType()); 1018 }); 1019 1020 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1021 signalPassFailure(); 1022 } 1023 1024 //===----------------------------------------------------------------------===// 1025 // Patterns for structural type conversions for the Async dialect operations. 1026 //===----------------------------------------------------------------------===// 1027 1028 namespace { 1029 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1030 public: 1031 using OpConversionPattern::OpConversionPattern; 1032 LogicalResult 1033 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1034 ConversionPatternRewriter &rewriter) const override { 1035 ExecuteOp newOp = 1036 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1037 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1038 newOp.getRegion().end()); 1039 1040 // Set operands and update block argument and result types. 1041 newOp->setOperands(adaptor.getOperands()); 1042 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1043 return failure(); 1044 for (auto result : newOp.getResults()) 1045 result.setType(typeConverter->convertType(result.getType())); 1046 1047 rewriter.replaceOp(op, newOp.getResults()); 1048 return success(); 1049 } 1050 }; 1051 1052 // Dummy pattern to trigger the appropriate type conversion / materialization. 1053 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1054 public: 1055 using OpConversionPattern::OpConversionPattern; 1056 LogicalResult 1057 matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1058 ConversionPatternRewriter &rewriter) const override { 1059 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1060 return success(); 1061 } 1062 }; 1063 1064 // Dummy pattern to trigger the appropriate type conversion / materialization. 1065 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1066 public: 1067 using OpConversionPattern::OpConversionPattern; 1068 LogicalResult 1069 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1070 ConversionPatternRewriter &rewriter) const override { 1071 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1072 return success(); 1073 } 1074 }; 1075 } // namespace 1076 1077 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1078 return std::make_unique<ConvertAsyncToLLVMPass>(); 1079 } 1080 1081 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1082 TypeConverter &typeConverter, RewritePatternSet &patterns, 1083 ConversionTarget &target) { 1084 typeConverter.addConversion([&](TokenType type) { return type; }); 1085 typeConverter.addConversion([&](ValueType type) { 1086 Type converted = typeConverter.convertType(type.getValueType()); 1087 return converted ? ValueType::get(converted) : converted; 1088 }); 1089 1090 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1091 typeConverter, patterns.getContext()); 1092 1093 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1094 [&](Operation *op) { return typeConverter.isLegal(op); }); 1095 } 1096