1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 19 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 27 #define DEBUG_TYPE "convert-async-to-llvm" 28 29 using namespace mlir; 30 using namespace mlir::async; 31 32 //===----------------------------------------------------------------------===// 33 // Async Runtime C API declaration. 34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 37 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 38 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 39 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 40 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 41 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 42 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 43 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 44 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 45 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 46 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 47 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 48 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 49 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 50 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 51 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 52 static constexpr const char *kGetValueStorage = 53 "mlirAsyncRuntimeGetValueStorage"; 54 static constexpr const char *kAddTokenToGroup = 55 "mlirAsyncRuntimeAddTokenToGroup"; 56 static constexpr const char *kAwaitTokenAndExecute = 57 "mlirAsyncRuntimeAwaitTokenAndExecute"; 58 static constexpr const char *kAwaitValueAndExecute = 59 "mlirAsyncRuntimeAwaitValueAndExecute"; 60 static constexpr const char *kAwaitAllAndExecute = 61 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 62 static constexpr const char *kGetNumWorkerThreads = 63 "mlirAsyncRuntimGetNumWorkerThreads"; 64 65 namespace { 66 /// Async Runtime API function types. 67 /// 68 /// Because we can't create API function signature for type parametrized 69 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 70 /// lowering all async data types become opaque pointers at runtime. 71 struct AsyncAPI { 72 // All async types are lowered to opaque i8* LLVM pointers at runtime. 73 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 74 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 75 } 76 77 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 78 return LLVM::LLVMTokenType::get(ctx); 79 } 80 81 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 82 auto ref = opaquePointerType(ctx); 83 auto count = IntegerType::get(ctx, 64); 84 return FunctionType::get(ctx, {ref, count}, {}); 85 } 86 87 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 88 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 89 } 90 91 static FunctionType createValueFunctionType(MLIRContext *ctx) { 92 auto i64 = IntegerType::get(ctx, 64); 93 auto value = opaquePointerType(ctx); 94 return FunctionType::get(ctx, {i64}, {value}); 95 } 96 97 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 98 auto i64 = IntegerType::get(ctx, 64); 99 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 100 } 101 102 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 103 auto value = opaquePointerType(ctx); 104 auto storage = opaquePointerType(ctx); 105 return FunctionType::get(ctx, {value}, {storage}); 106 } 107 108 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 109 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 110 } 111 112 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 113 auto value = opaquePointerType(ctx); 114 return FunctionType::get(ctx, {value}, {}); 115 } 116 117 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 118 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 119 } 120 121 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 122 auto value = opaquePointerType(ctx); 123 return FunctionType::get(ctx, {value}, {}); 124 } 125 126 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 127 auto i1 = IntegerType::get(ctx, 1); 128 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 129 } 130 131 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 132 auto value = opaquePointerType(ctx); 133 auto i1 = IntegerType::get(ctx, 1); 134 return FunctionType::get(ctx, {value}, {i1}); 135 } 136 137 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 138 auto i1 = IntegerType::get(ctx, 1); 139 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 140 } 141 142 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 143 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 144 } 145 146 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 147 auto value = opaquePointerType(ctx); 148 return FunctionType::get(ctx, {value}, {}); 149 } 150 151 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 152 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 153 } 154 155 static FunctionType executeFunctionType(MLIRContext *ctx) { 156 auto hdl = opaquePointerType(ctx); 157 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 158 return FunctionType::get(ctx, {hdl, resume}, {}); 159 } 160 161 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 162 auto i64 = IntegerType::get(ctx, 64); 163 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 164 {i64}); 165 } 166 167 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 168 auto hdl = opaquePointerType(ctx); 169 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 170 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 171 } 172 173 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 174 auto value = opaquePointerType(ctx); 175 auto hdl = opaquePointerType(ctx); 176 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 177 return FunctionType::get(ctx, {value, hdl, resume}, {}); 178 } 179 180 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 181 auto hdl = opaquePointerType(ctx); 182 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 183 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 184 } 185 186 static FunctionType getNumWorkerThreads(MLIRContext *ctx) { 187 return FunctionType::get(ctx, {}, {IndexType::get(ctx)}); 188 } 189 190 // Auxiliary coroutine resume intrinsic wrapper. 191 static Type resumeFunctionType(MLIRContext *ctx) { 192 auto voidTy = LLVM::LLVMVoidType::get(ctx); 193 auto i8Ptr = opaquePointerType(ctx); 194 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 195 } 196 }; 197 } // namespace 198 199 /// Adds Async Runtime C API declarations to the module. 200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 201 auto builder = 202 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 203 204 auto addFuncDecl = [&](StringRef name, FunctionType type) { 205 if (module.lookupSymbol(name)) 206 return; 207 builder.create<func::FuncOp>(name, type).setPrivate(); 208 }; 209 210 MLIRContext *ctx = module.getContext(); 211 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 212 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 213 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 214 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 215 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 216 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 217 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 218 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 219 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 220 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 221 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 222 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 223 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 224 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 225 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 226 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 227 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 228 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 229 addFuncDecl(kAwaitTokenAndExecute, 230 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 231 addFuncDecl(kAwaitValueAndExecute, 232 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 233 addFuncDecl(kAwaitAllAndExecute, 234 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 235 addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx)); 236 } 237 238 //===----------------------------------------------------------------------===// 239 // Coroutine resume function wrapper. 240 //===----------------------------------------------------------------------===// 241 242 static constexpr const char *kResume = "__resume"; 243 244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 245 /// intrinsics. We need this function to be able to pass it to the async 246 /// runtime execute API. 247 static void addResumeFunction(ModuleOp module) { 248 if (module.lookupSymbol(kResume)) 249 return; 250 251 MLIRContext *ctx = module.getContext(); 252 auto loc = module.getLoc(); 253 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 254 255 auto voidTy = LLVM::LLVMVoidType::get(ctx); 256 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 257 258 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 259 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 260 resumeOp.setPrivate(); 261 262 auto *block = resumeOp.addEntryBlock(); 263 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 264 265 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 266 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 267 } 268 269 //===----------------------------------------------------------------------===// 270 // Convert Async dialect types to LLVM types. 271 //===----------------------------------------------------------------------===// 272 273 namespace { 274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 275 /// their runtime type (opaque pointers) and does not convert any other types. 276 class AsyncRuntimeTypeConverter : public TypeConverter { 277 public: 278 AsyncRuntimeTypeConverter() { 279 addConversion([](Type type) { return type; }); 280 addConversion(convertAsyncTypes); 281 } 282 283 static Optional<Type> convertAsyncTypes(Type type) { 284 if (type.isa<TokenType, GroupType, ValueType>()) 285 return AsyncAPI::opaquePointerType(type.getContext()); 286 287 if (type.isa<CoroIdType, CoroStateType>()) 288 return AsyncAPI::tokenType(type.getContext()); 289 if (type.isa<CoroHandleType>()) 290 return AsyncAPI::opaquePointerType(type.getContext()); 291 292 return llvm::None; 293 } 294 }; 295 } // namespace 296 297 //===----------------------------------------------------------------------===// 298 // Convert async.coro.id to @llvm.coro.id intrinsic. 299 //===----------------------------------------------------------------------===// 300 301 namespace { 302 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 303 public: 304 using OpConversionPattern::OpConversionPattern; 305 306 LogicalResult 307 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 308 ConversionPatternRewriter &rewriter) const override { 309 auto token = AsyncAPI::tokenType(op->getContext()); 310 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 311 auto loc = op->getLoc(); 312 313 // Constants for initializing coroutine frame. 314 auto constZero = rewriter.create<LLVM::ConstantOp>( 315 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 316 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 317 318 // Get coroutine id: @llvm.coro.id. 319 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 320 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 321 322 return success(); 323 } 324 }; 325 } // namespace 326 327 //===----------------------------------------------------------------------===// 328 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 329 //===----------------------------------------------------------------------===// 330 331 namespace { 332 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 333 public: 334 using OpConversionPattern::OpConversionPattern; 335 336 LogicalResult 337 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 338 ConversionPatternRewriter &rewriter) const override { 339 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 340 auto loc = op->getLoc(); 341 342 // Get coroutine frame size: @llvm.coro.size.i64. 343 Value coroSize = 344 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 345 // Get coroutine frame alignment: @llvm.coro.align.i64. 346 Value coroAlign = 347 rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); 348 349 // Round up the size to be multiple of the alignment. Since aligned_alloc 350 // requires the size parameter be an integral multiple of the alignment 351 // parameter. 352 auto makeConstant = [&](uint64_t c) { 353 return rewriter.create<LLVM::ConstantOp>( 354 op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c)); 355 }; 356 coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); 357 coroSize = 358 rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); 359 Value negCoroAlign = 360 rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); 361 coroSize = 362 rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign); 363 364 // Allocate memory for the coroutine frame. 365 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( 366 op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); 367 auto coroAlloc = rewriter.create<LLVM::CallOp>( 368 loc, i8Ptr, SymbolRefAttr::get(allocFuncOp), 369 ValueRange{coroAlign, coroSize}); 370 371 // Begin a coroutine: @llvm.coro.begin. 372 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); 373 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 374 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 375 376 return success(); 377 } 378 }; 379 } // namespace 380 381 //===----------------------------------------------------------------------===// 382 // Convert async.coro.free to @llvm.coro.free intrinsic. 383 //===----------------------------------------------------------------------===// 384 385 namespace { 386 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 387 public: 388 using OpConversionPattern::OpConversionPattern; 389 390 LogicalResult 391 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 392 ConversionPatternRewriter &rewriter) const override { 393 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 394 auto loc = op->getLoc(); 395 396 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 397 auto coroMem = 398 rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands()); 399 400 // Free the memory. 401 auto freeFuncOp = 402 LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 403 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 404 SymbolRefAttr::get(freeFuncOp), 405 ValueRange(coroMem.getResult())); 406 407 return success(); 408 } 409 }; 410 } // namespace 411 412 //===----------------------------------------------------------------------===// 413 // Convert async.coro.end to @llvm.coro.end intrinsic. 414 //===----------------------------------------------------------------------===// 415 416 namespace { 417 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 418 public: 419 using OpConversionPattern::OpConversionPattern; 420 421 LogicalResult 422 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 423 ConversionPatternRewriter &rewriter) const override { 424 // We are not in the block that is part of the unwind sequence. 425 auto constFalse = rewriter.create<LLVM::ConstantOp>( 426 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 427 428 // Mark the end of a coroutine: @llvm.coro.end. 429 auto coroHdl = adaptor.handle(); 430 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 431 ValueRange({coroHdl, constFalse})); 432 rewriter.eraseOp(op); 433 434 return success(); 435 } 436 }; 437 } // namespace 438 439 //===----------------------------------------------------------------------===// 440 // Convert async.coro.save to @llvm.coro.save intrinsic. 441 //===----------------------------------------------------------------------===// 442 443 namespace { 444 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 445 public: 446 using OpConversionPattern::OpConversionPattern; 447 448 LogicalResult 449 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 450 ConversionPatternRewriter &rewriter) const override { 451 // Save the coroutine state: @llvm.coro.save 452 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 453 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 454 455 return success(); 456 } 457 }; 458 } // namespace 459 460 //===----------------------------------------------------------------------===// 461 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 462 //===----------------------------------------------------------------------===// 463 464 namespace { 465 466 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 467 /// branch to the appropriate block based on the return code. 468 /// 469 /// Before: 470 /// 471 /// ^suspended: 472 /// "opBefore"(...) 473 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 474 /// ^resume: 475 /// "op"(...) 476 /// ^cleanup: ... 477 /// ^suspend: ... 478 /// 479 /// After: 480 /// 481 /// ^suspended: 482 /// "opBefore"(...) 483 /// %suspend = llmv.intr.coro.suspend ... 484 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 485 /// ^resume: 486 /// "op"(...) 487 /// ^cleanup: ... 488 /// ^suspend: ... 489 /// 490 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 491 public: 492 using OpConversionPattern::OpConversionPattern; 493 494 LogicalResult 495 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 496 ConversionPatternRewriter &rewriter) const override { 497 auto i8 = rewriter.getIntegerType(8); 498 auto i32 = rewriter.getI32Type(); 499 auto loc = op->getLoc(); 500 501 // This is not a final suspension point. 502 auto constFalse = rewriter.create<LLVM::ConstantOp>( 503 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 504 505 // Suspend a coroutine: @llvm.coro.suspend 506 auto coroState = adaptor.state(); 507 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 508 loc, i8, ValueRange({coroState, constFalse})); 509 510 // Cast return code to i32. 511 512 // After a suspension point decide if we should branch into resume, cleanup 513 // or suspend block of the coroutine (see @llvm.coro.suspend return code 514 // documentation). 515 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 516 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 517 op.cleanupDest()}; 518 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 519 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 520 /*defaultDestination=*/op.suspendDest(), 521 /*defaultOperands=*/ValueRange(), 522 /*caseValues=*/caseValues, 523 /*caseDestinations=*/caseDest, 524 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 525 /*branchWeights=*/ArrayRef<int32_t>()); 526 527 return success(); 528 } 529 }; 530 } // namespace 531 532 //===----------------------------------------------------------------------===// 533 // Convert async.runtime.create to the corresponding runtime API call. 534 // 535 // To allocate storage for the async values we use getelementptr trick: 536 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 537 //===----------------------------------------------------------------------===// 538 539 namespace { 540 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 541 public: 542 using OpConversionPattern::OpConversionPattern; 543 544 LogicalResult 545 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 546 ConversionPatternRewriter &rewriter) const override { 547 TypeConverter *converter = getTypeConverter(); 548 Type resultType = op->getResultTypes()[0]; 549 550 // Tokens creation maps to a simple function call. 551 if (resultType.isa<TokenType>()) { 552 rewriter.replaceOpWithNewOp<func::CallOp>( 553 op, kCreateToken, converter->convertType(resultType)); 554 return success(); 555 } 556 557 // To create a value we need to compute the storage requirement. 558 if (auto value = resultType.dyn_cast<ValueType>()) { 559 // Returns the size requirements for the async value storage. 560 auto sizeOf = [&](ValueType valueType) -> Value { 561 auto loc = op->getLoc(); 562 auto i64 = rewriter.getI64Type(); 563 564 auto storedType = converter->convertType(valueType.getValueType()); 565 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 566 567 // %Size = getelementptr %T* null, int 1 568 // %SizeI = ptrtoint %T* %Size to i64 569 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 570 auto one = rewriter.create<LLVM::ConstantOp>( 571 loc, i64, rewriter.getI64IntegerAttr(1)); 572 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 573 one.getResult()); 574 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 575 }; 576 577 rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType, 578 sizeOf(value)); 579 580 return success(); 581 } 582 583 return rewriter.notifyMatchFailure(op, "unsupported async type"); 584 } 585 }; 586 } // namespace 587 588 //===----------------------------------------------------------------------===// 589 // Convert async.runtime.create_group to the corresponding runtime API call. 590 //===----------------------------------------------------------------------===// 591 592 namespace { 593 class RuntimeCreateGroupOpLowering 594 : public OpConversionPattern<RuntimeCreateGroupOp> { 595 public: 596 using OpConversionPattern::OpConversionPattern; 597 598 LogicalResult 599 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 600 ConversionPatternRewriter &rewriter) const override { 601 TypeConverter *converter = getTypeConverter(); 602 Type resultType = op.getResult().getType(); 603 604 rewriter.replaceOpWithNewOp<func::CallOp>( 605 op, kCreateGroup, converter->convertType(resultType), 606 adaptor.getOperands()); 607 return success(); 608 } 609 }; 610 } // namespace 611 612 //===----------------------------------------------------------------------===// 613 // Convert async.runtime.set_available to the corresponding runtime API call. 614 //===----------------------------------------------------------------------===// 615 616 namespace { 617 class RuntimeSetAvailableOpLowering 618 : public OpConversionPattern<RuntimeSetAvailableOp> { 619 public: 620 using OpConversionPattern::OpConversionPattern; 621 622 LogicalResult 623 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 624 ConversionPatternRewriter &rewriter) const override { 625 StringRef apiFuncName = 626 TypeSwitch<Type, StringRef>(op.operand().getType()) 627 .Case<TokenType>([](Type) { return kEmplaceToken; }) 628 .Case<ValueType>([](Type) { return kEmplaceValue; }); 629 630 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), 631 adaptor.getOperands()); 632 633 return success(); 634 } 635 }; 636 } // namespace 637 638 //===----------------------------------------------------------------------===// 639 // Convert async.runtime.set_error to the corresponding runtime API call. 640 //===----------------------------------------------------------------------===// 641 642 namespace { 643 class RuntimeSetErrorOpLowering 644 : public OpConversionPattern<RuntimeSetErrorOp> { 645 public: 646 using OpConversionPattern::OpConversionPattern; 647 648 LogicalResult 649 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 650 ConversionPatternRewriter &rewriter) const override { 651 StringRef apiFuncName = 652 TypeSwitch<Type, StringRef>(op.operand().getType()) 653 .Case<TokenType>([](Type) { return kSetTokenError; }) 654 .Case<ValueType>([](Type) { return kSetValueError; }); 655 656 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), 657 adaptor.getOperands()); 658 659 return success(); 660 } 661 }; 662 } // namespace 663 664 //===----------------------------------------------------------------------===// 665 // Convert async.runtime.is_error to the corresponding runtime API call. 666 //===----------------------------------------------------------------------===// 667 668 namespace { 669 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 670 public: 671 using OpConversionPattern::OpConversionPattern; 672 673 LogicalResult 674 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 675 ConversionPatternRewriter &rewriter) const override { 676 StringRef apiFuncName = 677 TypeSwitch<Type, StringRef>(op.operand().getType()) 678 .Case<TokenType>([](Type) { return kIsTokenError; }) 679 .Case<GroupType>([](Type) { return kIsGroupError; }) 680 .Case<ValueType>([](Type) { return kIsValueError; }); 681 682 rewriter.replaceOpWithNewOp<func::CallOp>( 683 op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands()); 684 return success(); 685 } 686 }; 687 } // namespace 688 689 //===----------------------------------------------------------------------===// 690 // Convert async.runtime.await to the corresponding runtime API call. 691 //===----------------------------------------------------------------------===// 692 693 namespace { 694 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 695 public: 696 using OpConversionPattern::OpConversionPattern; 697 698 LogicalResult 699 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 700 ConversionPatternRewriter &rewriter) const override { 701 StringRef apiFuncName = 702 TypeSwitch<Type, StringRef>(op.operand().getType()) 703 .Case<TokenType>([](Type) { return kAwaitToken; }) 704 .Case<ValueType>([](Type) { return kAwaitValue; }) 705 .Case<GroupType>([](Type) { return kAwaitGroup; }); 706 707 rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(), 708 adaptor.getOperands()); 709 rewriter.eraseOp(op); 710 711 return success(); 712 } 713 }; 714 } // namespace 715 716 //===----------------------------------------------------------------------===// 717 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 718 //===----------------------------------------------------------------------===// 719 720 namespace { 721 class RuntimeAwaitAndResumeOpLowering 722 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 723 public: 724 using OpConversionPattern::OpConversionPattern; 725 726 LogicalResult 727 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 728 ConversionPatternRewriter &rewriter) const override { 729 StringRef apiFuncName = 730 TypeSwitch<Type, StringRef>(op.operand().getType()) 731 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 732 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 733 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 734 735 Value operand = adaptor.operand(); 736 Value handle = adaptor.handle(); 737 738 // A pointer to coroutine resume intrinsic wrapper. 739 addResumeFunction(op->getParentOfType<ModuleOp>()); 740 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 741 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 742 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 743 744 rewriter.create<func::CallOp>( 745 op->getLoc(), apiFuncName, TypeRange(), 746 ValueRange({operand, handle, resumePtr.getRes()})); 747 rewriter.eraseOp(op); 748 749 return success(); 750 } 751 }; 752 } // namespace 753 754 //===----------------------------------------------------------------------===// 755 // Convert async.runtime.resume to the corresponding runtime API call. 756 //===----------------------------------------------------------------------===// 757 758 namespace { 759 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 760 public: 761 using OpConversionPattern::OpConversionPattern; 762 763 LogicalResult 764 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 765 ConversionPatternRewriter &rewriter) const override { 766 // A pointer to coroutine resume intrinsic wrapper. 767 addResumeFunction(op->getParentOfType<ModuleOp>()); 768 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 769 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 770 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 771 772 // Call async runtime API to execute a coroutine in the managed thread. 773 auto coroHdl = adaptor.handle(); 774 rewriter.replaceOpWithNewOp<func::CallOp>( 775 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); 776 777 return success(); 778 } 779 }; 780 } // namespace 781 782 //===----------------------------------------------------------------------===// 783 // Convert async.runtime.store to the corresponding runtime API call. 784 //===----------------------------------------------------------------------===// 785 786 namespace { 787 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 788 public: 789 using OpConversionPattern::OpConversionPattern; 790 791 LogicalResult 792 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 793 ConversionPatternRewriter &rewriter) const override { 794 Location loc = op->getLoc(); 795 796 // Get a pointer to the async value storage from the runtime. 797 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 798 auto storage = adaptor.storage(); 799 auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage, 800 TypeRange(i8Ptr), storage); 801 802 // Cast from i8* to the LLVM pointer type. 803 auto valueType = op.value().getType(); 804 auto llvmValueType = getTypeConverter()->convertType(valueType); 805 if (!llvmValueType) 806 return rewriter.notifyMatchFailure( 807 op, "failed to convert stored value type to LLVM type"); 808 809 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 810 loc, LLVM::LLVMPointerType::get(llvmValueType), 811 storagePtr.getResult(0)); 812 813 // Store the yielded value into the async value storage. 814 auto value = adaptor.value(); 815 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 816 817 // Erase the original runtime store operation. 818 rewriter.eraseOp(op); 819 820 return success(); 821 } 822 }; 823 } // namespace 824 825 //===----------------------------------------------------------------------===// 826 // Convert async.runtime.load to the corresponding runtime API call. 827 //===----------------------------------------------------------------------===// 828 829 namespace { 830 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 831 public: 832 using OpConversionPattern::OpConversionPattern; 833 834 LogicalResult 835 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 836 ConversionPatternRewriter &rewriter) const override { 837 Location loc = op->getLoc(); 838 839 // Get a pointer to the async value storage from the runtime. 840 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 841 auto storage = adaptor.storage(); 842 auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage, 843 TypeRange(i8Ptr), storage); 844 845 // Cast from i8* to the LLVM pointer type. 846 auto valueType = op.result().getType(); 847 auto llvmValueType = getTypeConverter()->convertType(valueType); 848 if (!llvmValueType) 849 return rewriter.notifyMatchFailure( 850 op, "failed to convert loaded value type to LLVM type"); 851 852 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 853 loc, LLVM::LLVMPointerType::get(llvmValueType), 854 storagePtr.getResult(0)); 855 856 // Load from the casted pointer. 857 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 858 859 return success(); 860 } 861 }; 862 } // namespace 863 864 //===----------------------------------------------------------------------===// 865 // Convert async.runtime.add_to_group to the corresponding runtime API call. 866 //===----------------------------------------------------------------------===// 867 868 namespace { 869 class RuntimeAddToGroupOpLowering 870 : public OpConversionPattern<RuntimeAddToGroupOp> { 871 public: 872 using OpConversionPattern::OpConversionPattern; 873 874 LogicalResult 875 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 876 ConversionPatternRewriter &rewriter) const override { 877 // Currently we can only add tokens to the group. 878 if (!op.operand().getType().isa<TokenType>()) 879 return rewriter.notifyMatchFailure(op, "only token type is supported"); 880 881 // Replace with a runtime API function call. 882 rewriter.replaceOpWithNewOp<func::CallOp>( 883 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 884 885 return success(); 886 } 887 }; 888 } // namespace 889 890 //===----------------------------------------------------------------------===// 891 // Convert async.runtime.num_worker_threads to the corresponding runtime API 892 // call. 893 //===----------------------------------------------------------------------===// 894 895 namespace { 896 class RuntimeNumWorkerThreadsOpLowering 897 : public OpConversionPattern<RuntimeNumWorkerThreadsOp> { 898 public: 899 using OpConversionPattern::OpConversionPattern; 900 901 LogicalResult 902 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor, 903 ConversionPatternRewriter &rewriter) const override { 904 905 // Replace with a runtime API function call. 906 rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads, 907 rewriter.getIndexType()); 908 909 return success(); 910 } 911 }; 912 } // namespace 913 914 //===----------------------------------------------------------------------===// 915 // Async reference counting ops lowering (`async.runtime.add_ref` and 916 // `async.runtime.drop_ref` to the corresponding API calls). 917 //===----------------------------------------------------------------------===// 918 919 namespace { 920 template <typename RefCountingOp> 921 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 922 public: 923 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 924 StringRef apiFunctionName) 925 : OpConversionPattern<RefCountingOp>(converter, ctx), 926 apiFunctionName(apiFunctionName) {} 927 928 LogicalResult 929 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 930 ConversionPatternRewriter &rewriter) const override { 931 auto count = rewriter.create<arith::ConstantOp>( 932 op->getLoc(), rewriter.getI64Type(), 933 rewriter.getI64IntegerAttr(op.count())); 934 935 auto operand = adaptor.operand(); 936 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName, 937 ValueRange({operand, count})); 938 939 return success(); 940 } 941 942 private: 943 StringRef apiFunctionName; 944 }; 945 946 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 947 public: 948 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 949 : RefCountingOpLowering(converter, ctx, kAddRef) {} 950 }; 951 952 class RuntimeDropRefOpLowering 953 : public RefCountingOpLowering<RuntimeDropRefOp> { 954 public: 955 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 956 : RefCountingOpLowering(converter, ctx, kDropRef) {} 957 }; 958 } // namespace 959 960 //===----------------------------------------------------------------------===// 961 // Convert return operations that return async values from async regions. 962 //===----------------------------------------------------------------------===// 963 964 namespace { 965 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> { 966 public: 967 using OpConversionPattern::OpConversionPattern; 968 969 LogicalResult 970 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 971 ConversionPatternRewriter &rewriter) const override { 972 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands()); 973 return success(); 974 } 975 }; 976 } // namespace 977 978 //===----------------------------------------------------------------------===// 979 980 namespace { 981 struct ConvertAsyncToLLVMPass 982 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 983 void runOnOperation() override; 984 }; 985 } // namespace 986 987 void ConvertAsyncToLLVMPass::runOnOperation() { 988 ModuleOp module = getOperation(); 989 MLIRContext *ctx = module->getContext(); 990 991 // Add declarations for most functions required by the coroutines lowering. 992 // We delay adding the resume function until it's needed because it currently 993 // fails to compile unless '-O0' is specified. 994 addAsyncRuntimeApiDeclarations(module); 995 996 // Lower async.runtime and async.coro operations to Async Runtime API and 997 // LLVM coroutine intrinsics. 998 999 // Convert async dialect types and operations to LLVM dialect. 1000 AsyncRuntimeTypeConverter converter; 1001 RewritePatternSet patterns(ctx); 1002 1003 // We use conversion to LLVM type to lower async.runtime load and store 1004 // operations. 1005 LLVMTypeConverter llvmConverter(ctx); 1006 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 1007 1008 // Convert async types in function signatures and function calls. 1009 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 1010 converter); 1011 populateCallOpTypeConversionPattern(patterns, converter); 1012 1013 // Convert return operations inside async.execute regions. 1014 patterns.add<ReturnOpOpConversion>(converter, ctx); 1015 1016 // Lower async.runtime operations to the async runtime API calls. 1017 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 1018 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 1019 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 1020 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering, 1021 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter, 1022 ctx); 1023 1024 // Lower async.runtime operations that rely on LLVM type converter to convert 1025 // from async value payload type to the LLVM type. 1026 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 1027 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 1028 ctx); 1029 1030 // Lower async coroutine operations to LLVM coroutine intrinsics. 1031 patterns 1032 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 1033 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1034 converter, ctx); 1035 1036 ConversionTarget target(*ctx); 1037 target.addLegalOp<arith::ConstantOp, func::ConstantOp, 1038 UnrealizedConversionCastOp>(); 1039 target.addLegalDialect<LLVM::LLVMDialect>(); 1040 1041 // All operations from Async dialect must be lowered to the runtime API and 1042 // LLVM intrinsics calls. 1043 target.addIllegalDialect<AsyncDialect>(); 1044 1045 // Add dynamic legality constraints to apply conversions defined above. 1046 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 1047 return converter.isSignatureLegal(op.getFunctionType()); 1048 }); 1049 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 1050 return converter.isLegal(op.getOperandTypes()); 1051 }); 1052 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 1053 return converter.isSignatureLegal(op.getCalleeType()); 1054 }); 1055 1056 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1057 signalPassFailure(); 1058 } 1059 1060 //===----------------------------------------------------------------------===// 1061 // Patterns for structural type conversions for the Async dialect operations. 1062 //===----------------------------------------------------------------------===// 1063 1064 namespace { 1065 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1066 public: 1067 using OpConversionPattern::OpConversionPattern; 1068 LogicalResult 1069 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1070 ConversionPatternRewriter &rewriter) const override { 1071 ExecuteOp newOp = 1072 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1073 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1074 newOp.getRegion().end()); 1075 1076 // Set operands and update block argument and result types. 1077 newOp->setOperands(adaptor.getOperands()); 1078 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1079 return failure(); 1080 for (auto result : newOp.getResults()) 1081 result.setType(typeConverter->convertType(result.getType())); 1082 1083 rewriter.replaceOp(op, newOp.getResults()); 1084 return success(); 1085 } 1086 }; 1087 1088 // Dummy pattern to trigger the appropriate type conversion / materialization. 1089 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1090 public: 1091 using OpConversionPattern::OpConversionPattern; 1092 LogicalResult 1093 matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1094 ConversionPatternRewriter &rewriter) const override { 1095 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1096 return success(); 1097 } 1098 }; 1099 1100 // Dummy pattern to trigger the appropriate type conversion / materialization. 1101 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1102 public: 1103 using OpConversionPattern::OpConversionPattern; 1104 LogicalResult 1105 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1106 ConversionPatternRewriter &rewriter) const override { 1107 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1108 return success(); 1109 } 1110 }; 1111 } // namespace 1112 1113 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1114 return std::make_unique<ConvertAsyncToLLVMPass>(); 1115 } 1116 1117 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1118 TypeConverter &typeConverter, RewritePatternSet &patterns, 1119 ConversionTarget &target) { 1120 typeConverter.addConversion([&](TokenType type) { return type; }); 1121 typeConverter.addConversion([&](ValueType type) { 1122 Type converted = typeConverter.convertType(type.getValueType()); 1123 return converted ? ValueType::get(converted) : converted; 1124 }); 1125 1126 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1127 typeConverter, patterns.getContext()); 1128 1129 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1130 [&](Operation *op) { return typeConverter.isLegal(op); }); 1131 } 1132