1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/TypeUtilities.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 27 #define DEBUG_TYPE "convert-async-to-llvm" 28 29 using namespace mlir; 30 using namespace mlir::async; 31 32 //===----------------------------------------------------------------------===// 33 // Async Runtime C API declaration. 34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 37 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 38 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 39 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 40 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 41 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 42 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 43 static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 44 static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 45 static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 46 static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 47 static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 48 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 49 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 50 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 51 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 52 static constexpr const char *kGetValueStorage = 53 "mlirAsyncRuntimeGetValueStorage"; 54 static constexpr const char *kAddTokenToGroup = 55 "mlirAsyncRuntimeAddTokenToGroup"; 56 static constexpr const char *kAwaitTokenAndExecute = 57 "mlirAsyncRuntimeAwaitTokenAndExecute"; 58 static constexpr const char *kAwaitValueAndExecute = 59 "mlirAsyncRuntimeAwaitValueAndExecute"; 60 static constexpr const char *kAwaitAllAndExecute = 61 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 62 static constexpr const char *kGetNumWorkerThreads = 63 "mlirAsyncRuntimGetNumWorkerThreads"; 64 65 namespace { 66 /// Async Runtime API function types. 67 /// 68 /// Because we can't create API function signature for type parametrized 69 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 70 /// lowering all async data types become opaque pointers at runtime. 71 struct AsyncAPI { 72 // All async types are lowered to opaque i8* LLVM pointers at runtime. 73 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 74 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 75 } 76 77 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 78 return LLVM::LLVMTokenType::get(ctx); 79 } 80 81 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 82 auto ref = opaquePointerType(ctx); 83 auto count = IntegerType::get(ctx, 64); 84 return FunctionType::get(ctx, {ref, count}, {}); 85 } 86 87 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 88 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 89 } 90 91 static FunctionType createValueFunctionType(MLIRContext *ctx) { 92 auto i64 = IntegerType::get(ctx, 64); 93 auto value = opaquePointerType(ctx); 94 return FunctionType::get(ctx, {i64}, {value}); 95 } 96 97 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 98 auto i64 = IntegerType::get(ctx, 64); 99 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 100 } 101 102 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 103 auto value = opaquePointerType(ctx); 104 auto storage = opaquePointerType(ctx); 105 return FunctionType::get(ctx, {value}, {storage}); 106 } 107 108 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 109 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 110 } 111 112 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 113 auto value = opaquePointerType(ctx); 114 return FunctionType::get(ctx, {value}, {}); 115 } 116 117 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 118 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 119 } 120 121 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 122 auto value = opaquePointerType(ctx); 123 return FunctionType::get(ctx, {value}, {}); 124 } 125 126 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 127 auto i1 = IntegerType::get(ctx, 1); 128 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 129 } 130 131 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 132 auto value = opaquePointerType(ctx); 133 auto i1 = IntegerType::get(ctx, 1); 134 return FunctionType::get(ctx, {value}, {i1}); 135 } 136 137 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 138 auto i1 = IntegerType::get(ctx, 1); 139 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 140 } 141 142 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 143 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 144 } 145 146 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 147 auto value = opaquePointerType(ctx); 148 return FunctionType::get(ctx, {value}, {}); 149 } 150 151 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 152 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 153 } 154 155 static FunctionType executeFunctionType(MLIRContext *ctx) { 156 auto hdl = opaquePointerType(ctx); 157 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 158 return FunctionType::get(ctx, {hdl, resume}, {}); 159 } 160 161 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 162 auto i64 = IntegerType::get(ctx, 64); 163 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 164 {i64}); 165 } 166 167 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 168 auto hdl = opaquePointerType(ctx); 169 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 170 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 171 } 172 173 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 174 auto value = opaquePointerType(ctx); 175 auto hdl = opaquePointerType(ctx); 176 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 177 return FunctionType::get(ctx, {value, hdl, resume}, {}); 178 } 179 180 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 181 auto hdl = opaquePointerType(ctx); 182 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 183 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 184 } 185 186 static FunctionType getNumWorkerThreads(MLIRContext *ctx) { 187 return FunctionType::get(ctx, {}, {IndexType::get(ctx)}); 188 } 189 190 // Auxiliary coroutine resume intrinsic wrapper. 191 static Type resumeFunctionType(MLIRContext *ctx) { 192 auto voidTy = LLVM::LLVMVoidType::get(ctx); 193 auto i8Ptr = opaquePointerType(ctx); 194 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 195 } 196 }; 197 } // namespace 198 199 /// Adds Async Runtime C API declarations to the module. 200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 201 auto builder = 202 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 203 204 auto addFuncDecl = [&](StringRef name, FunctionType type) { 205 if (module.lookupSymbol(name)) 206 return; 207 builder.create<FuncOp>(name, type).setPrivate(); 208 }; 209 210 MLIRContext *ctx = module.getContext(); 211 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 212 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 213 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 214 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 215 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 216 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 217 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 218 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 219 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 220 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 221 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 222 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 223 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 224 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 225 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 226 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 227 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 228 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 229 addFuncDecl(kAwaitTokenAndExecute, 230 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 231 addFuncDecl(kAwaitValueAndExecute, 232 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 233 addFuncDecl(kAwaitAllAndExecute, 234 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 235 addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx)); 236 } 237 238 //===----------------------------------------------------------------------===// 239 // Coroutine resume function wrapper. 240 //===----------------------------------------------------------------------===// 241 242 static constexpr const char *kResume = "__resume"; 243 244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 245 /// intrinsics. We need this function to be able to pass it to the async 246 /// runtime execute API. 247 static void addResumeFunction(ModuleOp module) { 248 if (module.lookupSymbol(kResume)) 249 return; 250 251 MLIRContext *ctx = module.getContext(); 252 auto loc = module.getLoc(); 253 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 254 255 auto voidTy = LLVM::LLVMVoidType::get(ctx); 256 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 257 258 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 259 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 260 resumeOp.setPrivate(); 261 262 auto *block = resumeOp.addEntryBlock(); 263 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 264 265 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 266 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 267 } 268 269 //===----------------------------------------------------------------------===// 270 // Convert Async dialect types to LLVM types. 271 //===----------------------------------------------------------------------===// 272 273 namespace { 274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 275 /// their runtime type (opaque pointers) and does not convert any other types. 276 class AsyncRuntimeTypeConverter : public TypeConverter { 277 public: 278 AsyncRuntimeTypeConverter() { 279 addConversion([](Type type) { return type; }); 280 addConversion(convertAsyncTypes); 281 } 282 283 static Optional<Type> convertAsyncTypes(Type type) { 284 if (type.isa<TokenType, GroupType, ValueType>()) 285 return AsyncAPI::opaquePointerType(type.getContext()); 286 287 if (type.isa<CoroIdType, CoroStateType>()) 288 return AsyncAPI::tokenType(type.getContext()); 289 if (type.isa<CoroHandleType>()) 290 return AsyncAPI::opaquePointerType(type.getContext()); 291 292 return llvm::None; 293 } 294 }; 295 } // namespace 296 297 //===----------------------------------------------------------------------===// 298 // Convert async.coro.id to @llvm.coro.id intrinsic. 299 //===----------------------------------------------------------------------===// 300 301 namespace { 302 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 303 public: 304 using OpConversionPattern::OpConversionPattern; 305 306 LogicalResult 307 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 308 ConversionPatternRewriter &rewriter) const override { 309 auto token = AsyncAPI::tokenType(op->getContext()); 310 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 311 auto loc = op->getLoc(); 312 313 // Constants for initializing coroutine frame. 314 auto constZero = rewriter.create<LLVM::ConstantOp>( 315 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 316 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 317 318 // Get coroutine id: @llvm.coro.id. 319 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 320 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 321 322 return success(); 323 } 324 }; 325 } // namespace 326 327 //===----------------------------------------------------------------------===// 328 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 329 //===----------------------------------------------------------------------===// 330 331 namespace { 332 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 333 public: 334 using OpConversionPattern::OpConversionPattern; 335 336 LogicalResult 337 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 338 ConversionPatternRewriter &rewriter) const override { 339 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 340 auto loc = op->getLoc(); 341 342 // Get coroutine frame size: @llvm.coro.size.i64. 343 Value coroSize = 344 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 345 // Get coroutine frame alignment: @llvm.coro.align.i64. 346 Value coroAlign = 347 rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); 348 349 // Round up the size to be multiple of the alignment. Since aligned_alloc 350 // requires the size parameter be an integral multiple of the alignment 351 // parameter. 352 auto makeConstant = [&](uint64_t c) { 353 return rewriter.create<LLVM::ConstantOp>( 354 op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c)); 355 }; 356 coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); 357 coroSize = 358 rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); 359 Value negCoroAlign = 360 rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); 361 coroSize = 362 rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign); 363 364 // Allocate memory for the coroutine frame. 365 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( 366 op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); 367 auto coroAlloc = rewriter.create<LLVM::CallOp>( 368 loc, i8Ptr, SymbolRefAttr::get(allocFuncOp), 369 ValueRange{coroAlign, coroSize}); 370 371 // Begin a coroutine: @llvm.coro.begin. 372 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); 373 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 374 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); 375 376 return success(); 377 } 378 }; 379 } // namespace 380 381 //===----------------------------------------------------------------------===// 382 // Convert async.coro.free to @llvm.coro.free intrinsic. 383 //===----------------------------------------------------------------------===// 384 385 namespace { 386 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 387 public: 388 using OpConversionPattern::OpConversionPattern; 389 390 LogicalResult 391 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 392 ConversionPatternRewriter &rewriter) const override { 393 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 394 auto loc = op->getLoc(); 395 396 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 397 auto coroMem = 398 rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands()); 399 400 // Free the memory. 401 auto freeFuncOp = 402 LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 403 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 404 SymbolRefAttr::get(freeFuncOp), 405 ValueRange(coroMem.getResult())); 406 407 return success(); 408 } 409 }; 410 } // namespace 411 412 //===----------------------------------------------------------------------===// 413 // Convert async.coro.end to @llvm.coro.end intrinsic. 414 //===----------------------------------------------------------------------===// 415 416 namespace { 417 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 418 public: 419 using OpConversionPattern::OpConversionPattern; 420 421 LogicalResult 422 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 423 ConversionPatternRewriter &rewriter) const override { 424 // We are not in the block that is part of the unwind sequence. 425 auto constFalse = rewriter.create<LLVM::ConstantOp>( 426 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 427 428 // Mark the end of a coroutine: @llvm.coro.end. 429 auto coroHdl = adaptor.handle(); 430 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(), 431 ValueRange({coroHdl, constFalse})); 432 rewriter.eraseOp(op); 433 434 return success(); 435 } 436 }; 437 } // namespace 438 439 //===----------------------------------------------------------------------===// 440 // Convert async.coro.save to @llvm.coro.save intrinsic. 441 //===----------------------------------------------------------------------===// 442 443 namespace { 444 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 445 public: 446 using OpConversionPattern::OpConversionPattern; 447 448 LogicalResult 449 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 450 ConversionPatternRewriter &rewriter) const override { 451 // Save the coroutine state: @llvm.coro.save 452 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 453 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 454 455 return success(); 456 } 457 }; 458 } // namespace 459 460 //===----------------------------------------------------------------------===// 461 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 462 //===----------------------------------------------------------------------===// 463 464 namespace { 465 466 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 467 /// branch to the appropriate block based on the return code. 468 /// 469 /// Before: 470 /// 471 /// ^suspended: 472 /// "opBefore"(...) 473 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 474 /// ^resume: 475 /// "op"(...) 476 /// ^cleanup: ... 477 /// ^suspend: ... 478 /// 479 /// After: 480 /// 481 /// ^suspended: 482 /// "opBefore"(...) 483 /// %suspend = llmv.intr.coro.suspend ... 484 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 485 /// ^resume: 486 /// "op"(...) 487 /// ^cleanup: ... 488 /// ^suspend: ... 489 /// 490 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 491 public: 492 using OpConversionPattern::OpConversionPattern; 493 494 LogicalResult 495 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 496 ConversionPatternRewriter &rewriter) const override { 497 auto i8 = rewriter.getIntegerType(8); 498 auto i32 = rewriter.getI32Type(); 499 auto loc = op->getLoc(); 500 501 // This is not a final suspension point. 502 auto constFalse = rewriter.create<LLVM::ConstantOp>( 503 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 504 505 // Suspend a coroutine: @llvm.coro.suspend 506 auto coroState = adaptor.state(); 507 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 508 loc, i8, ValueRange({coroState, constFalse})); 509 510 // Cast return code to i32. 511 512 // After a suspension point decide if we should branch into resume, cleanup 513 // or suspend block of the coroutine (see @llvm.coro.suspend return code 514 // documentation). 515 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 516 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 517 op.cleanupDest()}; 518 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 519 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 520 /*defaultDestination=*/op.suspendDest(), 521 /*defaultOperands=*/ValueRange(), 522 /*caseValues=*/caseValues, 523 /*caseDestinations=*/caseDest, 524 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 525 /*branchWeights=*/ArrayRef<int32_t>()); 526 527 return success(); 528 } 529 }; 530 } // namespace 531 532 //===----------------------------------------------------------------------===// 533 // Convert async.runtime.create to the corresponding runtime API call. 534 // 535 // To allocate storage for the async values we use getelementptr trick: 536 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 537 //===----------------------------------------------------------------------===// 538 539 namespace { 540 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 541 public: 542 using OpConversionPattern::OpConversionPattern; 543 544 LogicalResult 545 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 546 ConversionPatternRewriter &rewriter) const override { 547 TypeConverter *converter = getTypeConverter(); 548 Type resultType = op->getResultTypes()[0]; 549 550 // Tokens creation maps to a simple function call. 551 if (resultType.isa<TokenType>()) { 552 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken, 553 converter->convertType(resultType)); 554 return success(); 555 } 556 557 // To create a value we need to compute the storage requirement. 558 if (auto value = resultType.dyn_cast<ValueType>()) { 559 // Returns the size requirements for the async value storage. 560 auto sizeOf = [&](ValueType valueType) -> Value { 561 auto loc = op->getLoc(); 562 auto i64 = rewriter.getI64Type(); 563 564 auto storedType = converter->convertType(valueType.getValueType()); 565 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 566 567 // %Size = getelementptr %T* null, int 1 568 // %SizeI = ptrtoint %T* %Size to i64 569 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 570 auto one = rewriter.create<LLVM::ConstantOp>( 571 loc, i64, rewriter.getI64IntegerAttr(1)); 572 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 573 one.getResult()); 574 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 575 }; 576 577 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 578 sizeOf(value)); 579 580 return success(); 581 } 582 583 return rewriter.notifyMatchFailure(op, "unsupported async type"); 584 } 585 }; 586 } // namespace 587 588 //===----------------------------------------------------------------------===// 589 // Convert async.runtime.create_group to the corresponding runtime API call. 590 //===----------------------------------------------------------------------===// 591 592 namespace { 593 class RuntimeCreateGroupOpLowering 594 : public OpConversionPattern<RuntimeCreateGroupOp> { 595 public: 596 using OpConversionPattern::OpConversionPattern; 597 598 LogicalResult 599 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 600 ConversionPatternRewriter &rewriter) const override { 601 TypeConverter *converter = getTypeConverter(); 602 Type resultType = op.getResult().getType(); 603 604 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, 605 converter->convertType(resultType), 606 adaptor.getOperands()); 607 return success(); 608 } 609 }; 610 } // namespace 611 612 //===----------------------------------------------------------------------===// 613 // Convert async.runtime.set_available to the corresponding runtime API call. 614 //===----------------------------------------------------------------------===// 615 616 namespace { 617 class RuntimeSetAvailableOpLowering 618 : public OpConversionPattern<RuntimeSetAvailableOp> { 619 public: 620 using OpConversionPattern::OpConversionPattern; 621 622 LogicalResult 623 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 624 ConversionPatternRewriter &rewriter) const override { 625 StringRef apiFuncName = 626 TypeSwitch<Type, StringRef>(op.operand().getType()) 627 .Case<TokenType>([](Type) { return kEmplaceToken; }) 628 .Case<ValueType>([](Type) { return kEmplaceValue; }); 629 630 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 631 adaptor.getOperands()); 632 633 return success(); 634 } 635 }; 636 } // namespace 637 638 //===----------------------------------------------------------------------===// 639 // Convert async.runtime.set_error to the corresponding runtime API call. 640 //===----------------------------------------------------------------------===// 641 642 namespace { 643 class RuntimeSetErrorOpLowering 644 : public OpConversionPattern<RuntimeSetErrorOp> { 645 public: 646 using OpConversionPattern::OpConversionPattern; 647 648 LogicalResult 649 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 650 ConversionPatternRewriter &rewriter) const override { 651 StringRef apiFuncName = 652 TypeSwitch<Type, StringRef>(op.operand().getType()) 653 .Case<TokenType>([](Type) { return kSetTokenError; }) 654 .Case<ValueType>([](Type) { return kSetValueError; }); 655 656 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), 657 adaptor.getOperands()); 658 659 return success(); 660 } 661 }; 662 } // namespace 663 664 //===----------------------------------------------------------------------===// 665 // Convert async.runtime.is_error to the corresponding runtime API call. 666 //===----------------------------------------------------------------------===// 667 668 namespace { 669 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 670 public: 671 using OpConversionPattern::OpConversionPattern; 672 673 LogicalResult 674 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 675 ConversionPatternRewriter &rewriter) const override { 676 StringRef apiFuncName = 677 TypeSwitch<Type, StringRef>(op.operand().getType()) 678 .Case<TokenType>([](Type) { return kIsTokenError; }) 679 .Case<GroupType>([](Type) { return kIsGroupError; }) 680 .Case<ValueType>([](Type) { return kIsValueError; }); 681 682 rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(), 683 adaptor.getOperands()); 684 return success(); 685 } 686 }; 687 } // namespace 688 689 //===----------------------------------------------------------------------===// 690 // Convert async.runtime.await to the corresponding runtime API call. 691 //===----------------------------------------------------------------------===// 692 693 namespace { 694 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 695 public: 696 using OpConversionPattern::OpConversionPattern; 697 698 LogicalResult 699 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 700 ConversionPatternRewriter &rewriter) const override { 701 StringRef apiFuncName = 702 TypeSwitch<Type, StringRef>(op.operand().getType()) 703 .Case<TokenType>([](Type) { return kAwaitToken; }) 704 .Case<ValueType>([](Type) { return kAwaitValue; }) 705 .Case<GroupType>([](Type) { return kAwaitGroup; }); 706 707 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 708 adaptor.getOperands()); 709 rewriter.eraseOp(op); 710 711 return success(); 712 } 713 }; 714 } // namespace 715 716 //===----------------------------------------------------------------------===// 717 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 718 //===----------------------------------------------------------------------===// 719 720 namespace { 721 class RuntimeAwaitAndResumeOpLowering 722 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 723 public: 724 using OpConversionPattern::OpConversionPattern; 725 726 LogicalResult 727 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 728 ConversionPatternRewriter &rewriter) const override { 729 StringRef apiFuncName = 730 TypeSwitch<Type, StringRef>(op.operand().getType()) 731 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 732 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 733 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 734 735 Value operand = adaptor.operand(); 736 Value handle = adaptor.handle(); 737 738 // A pointer to coroutine resume intrinsic wrapper. 739 addResumeFunction(op->getParentOfType<ModuleOp>()); 740 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 741 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 742 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 743 744 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 745 ValueRange({operand, handle, resumePtr.getRes()})); 746 rewriter.eraseOp(op); 747 748 return success(); 749 } 750 }; 751 } // namespace 752 753 //===----------------------------------------------------------------------===// 754 // Convert async.runtime.resume to the corresponding runtime API call. 755 //===----------------------------------------------------------------------===// 756 757 namespace { 758 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 759 public: 760 using OpConversionPattern::OpConversionPattern; 761 762 LogicalResult 763 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 764 ConversionPatternRewriter &rewriter) const override { 765 // A pointer to coroutine resume intrinsic wrapper. 766 addResumeFunction(op->getParentOfType<ModuleOp>()); 767 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 768 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 769 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 770 771 // Call async runtime API to execute a coroutine in the managed thread. 772 auto coroHdl = adaptor.handle(); 773 rewriter.replaceOpWithNewOp<CallOp>( 774 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); 775 776 return success(); 777 } 778 }; 779 } // namespace 780 781 //===----------------------------------------------------------------------===// 782 // Convert async.runtime.store to the corresponding runtime API call. 783 //===----------------------------------------------------------------------===// 784 785 namespace { 786 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 787 public: 788 using OpConversionPattern::OpConversionPattern; 789 790 LogicalResult 791 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 792 ConversionPatternRewriter &rewriter) const override { 793 Location loc = op->getLoc(); 794 795 // Get a pointer to the async value storage from the runtime. 796 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 797 auto storage = adaptor.storage(); 798 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 799 TypeRange(i8Ptr), storage); 800 801 // Cast from i8* to the LLVM pointer type. 802 auto valueType = op.value().getType(); 803 auto llvmValueType = getTypeConverter()->convertType(valueType); 804 if (!llvmValueType) 805 return rewriter.notifyMatchFailure( 806 op, "failed to convert stored value type to LLVM type"); 807 808 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 809 loc, LLVM::LLVMPointerType::get(llvmValueType), 810 storagePtr.getResult(0)); 811 812 // Store the yielded value into the async value storage. 813 auto value = adaptor.value(); 814 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 815 816 // Erase the original runtime store operation. 817 rewriter.eraseOp(op); 818 819 return success(); 820 } 821 }; 822 } // namespace 823 824 //===----------------------------------------------------------------------===// 825 // Convert async.runtime.load to the corresponding runtime API call. 826 //===----------------------------------------------------------------------===// 827 828 namespace { 829 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 830 public: 831 using OpConversionPattern::OpConversionPattern; 832 833 LogicalResult 834 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 835 ConversionPatternRewriter &rewriter) const override { 836 Location loc = op->getLoc(); 837 838 // Get a pointer to the async value storage from the runtime. 839 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 840 auto storage = adaptor.storage(); 841 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 842 TypeRange(i8Ptr), storage); 843 844 // Cast from i8* to the LLVM pointer type. 845 auto valueType = op.result().getType(); 846 auto llvmValueType = getTypeConverter()->convertType(valueType); 847 if (!llvmValueType) 848 return rewriter.notifyMatchFailure( 849 op, "failed to convert loaded value type to LLVM type"); 850 851 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 852 loc, LLVM::LLVMPointerType::get(llvmValueType), 853 storagePtr.getResult(0)); 854 855 // Load from the casted pointer. 856 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 857 858 return success(); 859 } 860 }; 861 } // namespace 862 863 //===----------------------------------------------------------------------===// 864 // Convert async.runtime.add_to_group to the corresponding runtime API call. 865 //===----------------------------------------------------------------------===// 866 867 namespace { 868 class RuntimeAddToGroupOpLowering 869 : public OpConversionPattern<RuntimeAddToGroupOp> { 870 public: 871 using OpConversionPattern::OpConversionPattern; 872 873 LogicalResult 874 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 875 ConversionPatternRewriter &rewriter) const override { 876 // Currently we can only add tokens to the group. 877 if (!op.operand().getType().isa<TokenType>()) 878 return rewriter.notifyMatchFailure(op, "only token type is supported"); 879 880 // Replace with a runtime API function call. 881 rewriter.replaceOpWithNewOp<CallOp>( 882 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 883 884 return success(); 885 } 886 }; 887 } // namespace 888 889 //===----------------------------------------------------------------------===// 890 // Convert async.runtime.num_worker_threads to the corresponding runtime API 891 // call. 892 //===----------------------------------------------------------------------===// 893 894 namespace { 895 class RuntimeNumWorkerThreadsOpLowering 896 : public OpConversionPattern<RuntimeNumWorkerThreadsOp> { 897 public: 898 using OpConversionPattern::OpConversionPattern; 899 900 LogicalResult 901 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor, 902 ConversionPatternRewriter &rewriter) const override { 903 904 // Replace with a runtime API function call. 905 rewriter.replaceOpWithNewOp<CallOp>(op, kGetNumWorkerThreads, 906 rewriter.getIndexType()); 907 908 return success(); 909 } 910 }; 911 } // namespace 912 913 //===----------------------------------------------------------------------===// 914 // Async reference counting ops lowering (`async.runtime.add_ref` and 915 // `async.runtime.drop_ref` to the corresponding API calls). 916 //===----------------------------------------------------------------------===// 917 918 namespace { 919 template <typename RefCountingOp> 920 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 921 public: 922 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 923 StringRef apiFunctionName) 924 : OpConversionPattern<RefCountingOp>(converter, ctx), 925 apiFunctionName(apiFunctionName) {} 926 927 LogicalResult 928 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 929 ConversionPatternRewriter &rewriter) const override { 930 auto count = rewriter.create<arith::ConstantOp>( 931 op->getLoc(), rewriter.getI64Type(), 932 rewriter.getI64IntegerAttr(op.count())); 933 934 auto operand = adaptor.operand(); 935 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 936 ValueRange({operand, count})); 937 938 return success(); 939 } 940 941 private: 942 StringRef apiFunctionName; 943 }; 944 945 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 946 public: 947 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 948 : RefCountingOpLowering(converter, ctx, kAddRef) {} 949 }; 950 951 class RuntimeDropRefOpLowering 952 : public RefCountingOpLowering<RuntimeDropRefOp> { 953 public: 954 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 955 : RefCountingOpLowering(converter, ctx, kDropRef) {} 956 }; 957 } // namespace 958 959 //===----------------------------------------------------------------------===// 960 // Convert return operations that return async values from async regions. 961 //===----------------------------------------------------------------------===// 962 963 namespace { 964 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 965 public: 966 using OpConversionPattern::OpConversionPattern; 967 968 LogicalResult 969 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 970 ConversionPatternRewriter &rewriter) const override { 971 rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands()); 972 return success(); 973 } 974 }; 975 } // namespace 976 977 //===----------------------------------------------------------------------===// 978 979 namespace { 980 struct ConvertAsyncToLLVMPass 981 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 982 void runOnOperation() override; 983 }; 984 } // namespace 985 986 void ConvertAsyncToLLVMPass::runOnOperation() { 987 ModuleOp module = getOperation(); 988 MLIRContext *ctx = module->getContext(); 989 990 // Add declarations for most functions required by the coroutines lowering. 991 // We delay adding the resume function until it's needed because it currently 992 // fails to compile unless '-O0' is specified. 993 addAsyncRuntimeApiDeclarations(module); 994 995 // Lower async.runtime and async.coro operations to Async Runtime API and 996 // LLVM coroutine intrinsics. 997 998 // Convert async dialect types and operations to LLVM dialect. 999 AsyncRuntimeTypeConverter converter; 1000 RewritePatternSet patterns(ctx); 1001 1002 // We use conversion to LLVM type to lower async.runtime load and store 1003 // operations. 1004 LLVMTypeConverter llvmConverter(ctx); 1005 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 1006 1007 // Convert async types in function signatures and function calls. 1008 populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, converter); 1009 populateCallOpTypeConversionPattern(patterns, converter); 1010 1011 // Convert return operations inside async.execute regions. 1012 patterns.add<ReturnOpOpConversion>(converter, ctx); 1013 1014 // Lower async.runtime operations to the async runtime API calls. 1015 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 1016 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 1017 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 1018 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering, 1019 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter, 1020 ctx); 1021 1022 // Lower async.runtime operations that rely on LLVM type converter to convert 1023 // from async value payload type to the LLVM type. 1024 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 1025 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter, 1026 ctx); 1027 1028 // Lower async coroutine operations to LLVM coroutine intrinsics. 1029 patterns 1030 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 1031 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1032 converter, ctx); 1033 1034 ConversionTarget target(*ctx); 1035 target 1036 .addLegalOp<arith::ConstantOp, ConstantOp, UnrealizedConversionCastOp>(); 1037 target.addLegalDialect<LLVM::LLVMDialect>(); 1038 1039 // All operations from Async dialect must be lowered to the runtime API and 1040 // LLVM intrinsics calls. 1041 target.addIllegalDialect<AsyncDialect>(); 1042 1043 // Add dynamic legality constraints to apply conversions defined above. 1044 target.addDynamicallyLegalOp<FuncOp>( 1045 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1046 target.addDynamicallyLegalOp<ReturnOp>( 1047 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1048 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1049 return converter.isSignatureLegal(op.getCalleeType()); 1050 }); 1051 1052 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1053 signalPassFailure(); 1054 } 1055 1056 //===----------------------------------------------------------------------===// 1057 // Patterns for structural type conversions for the Async dialect operations. 1058 //===----------------------------------------------------------------------===// 1059 1060 namespace { 1061 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1062 public: 1063 using OpConversionPattern::OpConversionPattern; 1064 LogicalResult 1065 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1066 ConversionPatternRewriter &rewriter) const override { 1067 ExecuteOp newOp = 1068 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1069 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1070 newOp.getRegion().end()); 1071 1072 // Set operands and update block argument and result types. 1073 newOp->setOperands(adaptor.getOperands()); 1074 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1075 return failure(); 1076 for (auto result : newOp.getResults()) 1077 result.setType(typeConverter->convertType(result.getType())); 1078 1079 rewriter.replaceOp(op, newOp.getResults()); 1080 return success(); 1081 } 1082 }; 1083 1084 // Dummy pattern to trigger the appropriate type conversion / materialization. 1085 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1086 public: 1087 using OpConversionPattern::OpConversionPattern; 1088 LogicalResult 1089 matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1090 ConversionPatternRewriter &rewriter) const override { 1091 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1092 return success(); 1093 } 1094 }; 1095 1096 // Dummy pattern to trigger the appropriate type conversion / materialization. 1097 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1098 public: 1099 using OpConversionPattern::OpConversionPattern; 1100 LogicalResult 1101 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1102 ConversionPatternRewriter &rewriter) const override { 1103 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1104 return success(); 1105 } 1106 }; 1107 } // namespace 1108 1109 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1110 return std::make_unique<ConvertAsyncToLLVMPass>(); 1111 } 1112 1113 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1114 TypeConverter &typeConverter, RewritePatternSet &patterns, 1115 ConversionTarget &target) { 1116 typeConverter.addConversion([&](TokenType type) { return type; }); 1117 typeConverter.addConversion([&](ValueType type) { 1118 Type converted = typeConverter.convertType(type.getValueType()); 1119 return converted ? ValueType::get(converted) : converted; 1120 }); 1121 1122 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1123 typeConverter, patterns.getContext()); 1124 1125 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1126 [&](Operation *op) { return typeConverter.isLegal(op); }); 1127 } 1128