1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/BlockAndValueMapping.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "mlir/Transforms/RegionUtils.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/Support/FormatVariadic.h" 25 26 #define DEBUG_TYPE "convert-async-to-llvm" 27 28 using namespace mlir; 29 using namespace mlir::async; 30 31 // Prefix for functions outlined from `async.execute` op regions. 32 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 33 34 //===----------------------------------------------------------------------===// 35 // Async Runtime C API declaration. 36 //===----------------------------------------------------------------------===// 37 38 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 39 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 40 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 41 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 42 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 43 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 44 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 45 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 46 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 47 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 48 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 49 static constexpr const char *kGetValueStorage = 50 "mlirAsyncRuntimeGetValueStorage"; 51 static constexpr const char *kAddTokenToGroup = 52 "mlirAsyncRuntimeAddTokenToGroup"; 53 static constexpr const char *kAwaitTokenAndExecute = 54 "mlirAsyncRuntimeAwaitTokenAndExecute"; 55 static constexpr const char *kAwaitValueAndExecute = 56 "mlirAsyncRuntimeAwaitValueAndExecute"; 57 static constexpr const char *kAwaitAllAndExecute = 58 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 59 60 namespace { 61 /// Async Runtime API function types. 62 /// 63 /// Because we can't create API function signature for type parametrized 64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 65 /// lowering all async data types become opaque pointers at runtime. 66 struct AsyncAPI { 67 // All async types are lowered to opaque i8* LLVM pointers at runtime. 68 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 69 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 70 } 71 72 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 73 auto ref = opaquePointerType(ctx); 74 auto count = IntegerType::get(ctx, 32); 75 return FunctionType::get(ctx, {ref, count}, {}); 76 } 77 78 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 79 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 80 } 81 82 static FunctionType createValueFunctionType(MLIRContext *ctx) { 83 auto i32 = IntegerType::get(ctx, 32); 84 auto value = opaquePointerType(ctx); 85 return FunctionType::get(ctx, {i32}, {value}); 86 } 87 88 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 89 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 90 } 91 92 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 93 auto value = opaquePointerType(ctx); 94 auto storage = opaquePointerType(ctx); 95 return FunctionType::get(ctx, {value}, {storage}); 96 } 97 98 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 99 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 100 } 101 102 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 103 auto value = opaquePointerType(ctx); 104 return FunctionType::get(ctx, {value}, {}); 105 } 106 107 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 108 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 109 } 110 111 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 112 auto value = opaquePointerType(ctx); 113 return FunctionType::get(ctx, {value}, {}); 114 } 115 116 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 117 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 118 } 119 120 static FunctionType executeFunctionType(MLIRContext *ctx) { 121 auto hdl = opaquePointerType(ctx); 122 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 123 return FunctionType::get(ctx, {hdl, resume}, {}); 124 } 125 126 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 127 auto i64 = IntegerType::get(ctx, 64); 128 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 129 {i64}); 130 } 131 132 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 133 auto hdl = opaquePointerType(ctx); 134 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 135 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 136 } 137 138 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 139 auto value = opaquePointerType(ctx); 140 auto hdl = opaquePointerType(ctx); 141 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 142 return FunctionType::get(ctx, {value, hdl, resume}, {}); 143 } 144 145 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 146 auto hdl = opaquePointerType(ctx); 147 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 148 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 149 } 150 151 // Auxiliary coroutine resume intrinsic wrapper. 152 static Type resumeFunctionType(MLIRContext *ctx) { 153 auto voidTy = LLVM::LLVMVoidType::get(ctx); 154 auto i8Ptr = opaquePointerType(ctx); 155 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 156 } 157 }; 158 } // namespace 159 160 /// Adds Async Runtime C API declarations to the module. 161 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 162 auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), 163 module.getBody()); 164 165 auto addFuncDecl = [&](StringRef name, FunctionType type) { 166 if (module.lookupSymbol(name)) 167 return; 168 builder.create<FuncOp>(name, type).setPrivate(); 169 }; 170 171 MLIRContext *ctx = module.getContext(); 172 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 173 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 174 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 175 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 176 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 177 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 178 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 179 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 180 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 181 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 182 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 183 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 184 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 185 addFuncDecl(kAwaitTokenAndExecute, 186 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 187 addFuncDecl(kAwaitValueAndExecute, 188 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 189 addFuncDecl(kAwaitAllAndExecute, 190 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 191 } 192 193 //===----------------------------------------------------------------------===// 194 // LLVM coroutines intrinsics declarations. 195 //===----------------------------------------------------------------------===// 196 197 static constexpr const char *kCoroId = "llvm.coro.id"; 198 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 199 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 200 static constexpr const char *kCoroSave = "llvm.coro.save"; 201 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 202 static constexpr const char *kCoroEnd = "llvm.coro.end"; 203 static constexpr const char *kCoroFree = "llvm.coro.free"; 204 static constexpr const char *kCoroResume = "llvm.coro.resume"; 205 206 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 207 StringRef name, Type ret, ArrayRef<Type> params) { 208 if (module.lookupSymbol(name)) 209 return; 210 Type type = LLVM::LLVMFunctionType::get(ret, params); 211 builder.create<LLVM::LLVMFuncOp>(name, type); 212 } 213 214 /// Adds coroutine intrinsics declarations to the module. 215 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 216 using namespace mlir::LLVM; 217 218 MLIRContext *ctx = module.getContext(); 219 ImplicitLocOpBuilder builder(module.getLoc(), 220 module.getBody()->getTerminator()); 221 222 auto token = LLVMTokenType::get(ctx); 223 auto voidTy = LLVMVoidType::get(ctx); 224 225 auto i8 = IntegerType::get(ctx, 8); 226 auto i1 = IntegerType::get(ctx, 1); 227 auto i32 = IntegerType::get(ctx, 32); 228 auto i64 = IntegerType::get(ctx, 64); 229 auto i8Ptr = LLVMPointerType::get(i8); 230 231 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); 232 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); 233 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); 234 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); 235 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); 236 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); 237 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); 238 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // Add malloc/free declarations to the module. 243 //===----------------------------------------------------------------------===// 244 245 static constexpr const char *kMalloc = "malloc"; 246 static constexpr const char *kFree = "free"; 247 248 /// Adds malloc/free declarations to the module. 249 static void addCRuntimeDeclarations(ModuleOp module) { 250 using namespace mlir::LLVM; 251 252 MLIRContext *ctx = module.getContext(); 253 ImplicitLocOpBuilder builder(module.getLoc(), 254 module.getBody()->getTerminator()); 255 256 auto voidTy = LLVMVoidType::get(ctx); 257 auto i64 = IntegerType::get(ctx, 64); 258 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 259 260 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 261 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 262 } 263 264 //===----------------------------------------------------------------------===// 265 // Coroutine resume function wrapper. 266 //===----------------------------------------------------------------------===// 267 268 static constexpr const char *kResume = "__resume"; 269 270 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 271 /// intrinsics. We need this function to be able to pass it to the async 272 /// runtime execute API. 273 static void addResumeFunction(ModuleOp module) { 274 MLIRContext *ctx = module.getContext(); 275 276 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 277 Location loc = module.getLoc(); 278 279 if (module.lookupSymbol(kResume)) 280 return; 281 282 auto voidTy = LLVM::LLVMVoidType::get(ctx); 283 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 284 285 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 286 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 287 resumeOp.setPrivate(); 288 289 auto *block = resumeOp.addEntryBlock(); 290 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 291 292 blockBuilder.create<LLVM::CallOp>(TypeRange(), 293 blockBuilder.getSymbolRefAttr(kCoroResume), 294 resumeOp.getArgument(0)); 295 296 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // async.execute op outlining to the coroutine functions. 301 //===----------------------------------------------------------------------===// 302 303 /// Function targeted for coroutine transformation has two additional blocks at 304 /// the end: coroutine cleanup and coroutine suspension. 305 /// 306 /// async.await op lowering additionaly creates a resume block for each 307 /// operation to enable non-blocking waiting via coroutine suspension. 308 namespace { 309 struct CoroMachinery { 310 // Async execute region returns a completion token, and an async value for 311 // each yielded value. 312 // 313 // %token, %result = async.execute -> !async.value<T> { 314 // %0 = constant ... : T 315 // async.yield %0 : T 316 // } 317 Value asyncToken; // token representing completion of the async region 318 llvm::SmallVector<Value, 4> returnValues; // returned async values 319 320 Value coroHandle; 321 Block *cleanup; 322 Block *suspend; 323 }; 324 } // namespace 325 326 /// Builds an coroutine template compatible with LLVM coroutines lowering. 327 /// 328 /// - `entry` block sets up the coroutine. 329 /// - `cleanup` block cleans up the coroutine state. 330 /// - `suspend block after the @llvm.coro.end() defines what value will be 331 /// returned to the initial caller of a coroutine. Everything before the 332 /// @llvm.coro.end() will be executed at every suspension point. 333 /// 334 /// Coroutine structure (only the important bits): 335 /// 336 /// func @async_execute_fn(<function-arguments>) 337 /// -> (!async.token, !async.value<T>) 338 /// { 339 /// ^entryBlock(<function-arguments>): 340 /// %token = <async token> : !async.token // create async runtime token 341 /// %value = <async value> : !async.value<T> // create async value 342 /// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle 343 /// br ^cleanup 344 /// 345 /// ^cleanup: 346 /// llvm.call @llvm.coro.free(...) // delete coroutine state 347 /// br ^suspend 348 /// 349 /// ^suspend: 350 /// llvm.call @llvm.coro.end(...) // marks the end of a coroutine 351 /// return %token, %value : !async.token, !async.value<T> 352 /// } 353 /// 354 /// The actual code for the async.execute operation body region will be inserted 355 /// before the entry block terminator. 356 /// 357 /// 358 static CoroMachinery setupCoroMachinery(FuncOp func) { 359 assert(func.getBody().empty() && "Function must have empty body"); 360 361 MLIRContext *ctx = func.getContext(); 362 363 auto token = LLVM::LLVMTokenType::get(ctx); 364 auto i1 = IntegerType::get(ctx, 1); 365 auto i32 = IntegerType::get(ctx, 32); 366 auto i64 = IntegerType::get(ctx, 64); 367 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 368 369 Block *entryBlock = func.addEntryBlock(); 370 Location loc = func.getBody().getLoc(); 371 372 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock); 373 374 // ------------------------------------------------------------------------ // 375 // Allocate async tokens/values that we will return from a ramp function. 376 // ------------------------------------------------------------------------ // 377 auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx)); 378 379 // Async value operands and results must be convertible to LLVM types. This is 380 // verified before the function outlining. 381 LLVMTypeConverter converter(ctx); 382 383 // Returns the size requirements for the async value storage. 384 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 385 auto sizeOf = [&](ValueType valueType) -> Value { 386 auto storedType = converter.convertType(valueType.getValueType()); 387 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 388 389 // %Size = getelementptr %T* null, int 1 390 // %SizeI = ptrtoint %T* %Size to i32 391 auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType); 392 auto one = builder.create<LLVM::ConstantOp>(loc, i32, 393 builder.getI32IntegerAttr(1)); 394 auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 395 one.getResult()); 396 return builder.create<LLVM::PtrToIntOp>(loc, i32, gep); 397 }; 398 399 // We use the `async.value` type as a return type although it does not match 400 // the `kCreateValue` function signature, because it will be later lowered to 401 // the runtime type (opaque i8* pointer). 402 llvm::SmallVector<CallOp, 4> createValues; 403 for (auto resultType : func.getCallableResults().drop_front(1)) 404 createValues.emplace_back(builder.create<CallOp>( 405 loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>()))); 406 407 auto createdValues = llvm::map_range( 408 createValues, [](CallOp call) { return call.getResult(0); }); 409 llvm::SmallVector<Value, 4> returnValues(createdValues.begin(), 410 createdValues.end()); 411 412 // ------------------------------------------------------------------------ // 413 // Initialize coroutine: allocate frame, get coroutine handle. 414 // ------------------------------------------------------------------------ // 415 416 // Constants for initializing coroutine frame. 417 auto constZero = 418 builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0)); 419 auto constFalse = 420 builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false)); 421 auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr); 422 423 // Get coroutine id: @llvm.coro.id 424 auto coroId = builder.create<LLVM::CallOp>( 425 token, builder.getSymbolRefAttr(kCoroId), 426 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 427 428 // Get coroutine frame size: @llvm.coro.size.i64 429 auto coroSize = builder.create<LLVM::CallOp>( 430 i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 431 432 // Allocate memory for coroutine frame. 433 auto coroAlloc = 434 builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc), 435 ValueRange(coroSize.getResult(0))); 436 437 // Begin a coroutine: @llvm.coro.begin 438 auto coroHdl = builder.create<LLVM::CallOp>( 439 i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 440 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 441 442 Block *cleanupBlock = func.addBlock(); 443 Block *suspendBlock = func.addBlock(); 444 445 // ------------------------------------------------------------------------ // 446 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 447 // ------------------------------------------------------------------------ // 448 builder.setInsertionPointToStart(cleanupBlock); 449 450 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 451 auto coroMem = builder.create<LLVM::CallOp>( 452 i8Ptr, builder.getSymbolRefAttr(kCoroFree), 453 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 454 455 // Free the memory. 456 builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree), 457 ValueRange(coroMem.getResult(0))); 458 // Branch into the suspend block. 459 builder.create<BranchOp>(suspendBlock); 460 461 // ------------------------------------------------------------------------ // 462 // Coroutine suspend block: mark the end of a coroutine and return allocated 463 // async token. 464 // ------------------------------------------------------------------------ // 465 builder.setInsertionPointToStart(suspendBlock); 466 467 // Mark the end of a coroutine: @llvm.coro.end. 468 builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd), 469 ValueRange({coroHdl.getResult(0), constFalse})); 470 471 // Return created `async.token` and `async.values` from the suspend block. 472 // This will be the return value of a coroutine ramp function. 473 SmallVector<Value, 4> ret{createToken.getResult(0)}; 474 ret.insert(ret.end(), returnValues.begin(), returnValues.end()); 475 builder.create<ReturnOp>(loc, ret); 476 477 // Branch from the entry block to the cleanup block to create a valid CFG. 478 builder.setInsertionPointToEnd(entryBlock); 479 480 builder.create<BranchOp>(cleanupBlock); 481 482 // `async.await` op lowering will create resume blocks for async 483 // continuations, and will conditionally branch to cleanup or suspend blocks. 484 485 CoroMachinery machinery; 486 machinery.asyncToken = createToken.getResult(0); 487 machinery.returnValues = returnValues; 488 machinery.coroHandle = coroHdl.getResult(0); 489 machinery.cleanup = cleanupBlock; 490 machinery.suspend = suspendBlock; 491 return machinery; 492 } 493 494 /// Add a LLVM coroutine suspension point to the end of suspended block, to 495 /// resume execution in resume block. The caller is responsible for creating the 496 /// two suspended/resume blocks with the desired ops contained in each block. 497 /// This function merely provides the required control flow logic. 498 /// 499 /// `coroState` must be a value returned from the call to @llvm.coro.save(...) 500 /// intrinsic (saved coroutine state). 501 /// 502 /// Before: 503 /// 504 /// ^bb0: 505 /// "opBefore"(...) 506 /// "op"(...) 507 /// ^cleanup: ... 508 /// ^suspend: ... 509 /// ^resume: 510 /// "op"(...) 511 /// 512 /// After: 513 /// 514 /// ^bb0: 515 /// "opBefore"(...) 516 /// %suspend = llmv.call @llvm.coro.suspend(...) 517 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 518 /// ^resume: 519 /// "op"(...) 520 /// ^cleanup: ... 521 /// ^suspend: ... 522 /// 523 static void addSuspensionPoint(CoroMachinery coro, Value coroState, 524 Operation *op, Block *suspended, Block *resume, 525 OpBuilder &builder) { 526 Location loc = op->getLoc(); 527 MLIRContext *ctx = op->getContext(); 528 auto i1 = IntegerType::get(ctx, 1); 529 auto i8 = IntegerType::get(ctx, 8); 530 531 // Add a coroutine suspension in place of original `op` in the split block. 532 OpBuilder::InsertionGuard guard(builder); 533 builder.setInsertionPointToEnd(suspended); 534 535 auto constFalse = 536 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 537 538 // Suspend a coroutine: @llvm.coro.suspend 539 auto coroSuspend = builder.create<LLVM::CallOp>( 540 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 541 ValueRange({coroState, constFalse})); 542 543 // After a suspension point decide if we should branch into resume, cleanup 544 // or suspend block of the coroutine (see @llvm.coro.suspend return code 545 // documentation). 546 auto constZero = 547 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 548 auto constNegOne = 549 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 550 551 Block *resumeOrCleanup = builder.createBlock(resume); 552 553 // Suspend the coroutine ...? 554 builder.setInsertionPointToEnd(suspended); 555 auto isNegOne = builder.create<LLVM::ICmpOp>( 556 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 557 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 558 /*falseDest=*/resumeOrCleanup); 559 560 // ... or resume or cleanup the coroutine? 561 builder.setInsertionPointToStart(resumeOrCleanup); 562 auto isZero = builder.create<LLVM::ICmpOp>( 563 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 564 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 565 /*falseDest=*/coro.cleanup); 566 } 567 568 /// Outline the body region attached to the `async.execute` op into a standalone 569 /// function. 570 /// 571 /// Note that this is not reversible transformation. 572 static std::pair<FuncOp, CoroMachinery> 573 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 574 ModuleOp module = execute->getParentOfType<ModuleOp>(); 575 576 MLIRContext *ctx = module.getContext(); 577 Location loc = execute.getLoc(); 578 579 // Collect all outlined function inputs. 580 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 581 execute.dependencies().end()); 582 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 583 getUsedValuesDefinedAbove(execute.body(), functionInputs); 584 585 // Collect types for the outlined function inputs and outputs. 586 auto typesRange = llvm::map_range( 587 functionInputs, [](Value value) { return value.getType(); }); 588 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 589 auto outputTypes = execute.getResultTypes(); 590 591 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 592 auto funcAttrs = ArrayRef<NamedAttribute>(); 593 594 // TODO: Derive outlined function name from the parent FuncOp (support 595 // multiple nested async.execute operations). 596 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 597 symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); 598 599 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 600 601 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 602 // blocks, adding llvm.coro instrinsics and setting up control flow. 603 CoroMachinery coro = setupCoroMachinery(func); 604 605 // Suspend async function at the end of an entry block, and resume it using 606 // Async execute API (execution will be resumed in a thread managed by the 607 // async runtime). 608 Block *entryBlock = &func.getBlocks().front(); 609 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 610 611 // A pointer to coroutine resume intrinsic wrapper. 612 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 613 auto resumePtr = builder.create<LLVM::AddressOfOp>( 614 LLVM::LLVMPointerType::get(resumeFnTy), kResume); 615 616 // Save the coroutine state: @llvm.coro.save 617 auto coroSave = builder.create<LLVM::CallOp>( 618 LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 619 ValueRange({coro.coroHandle})); 620 621 // Call async runtime API to execute a coroutine in the managed thread. 622 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 623 builder.create<CallOp>(TypeRange(), kExecute, executeArgs); 624 625 // Split the entry block before the terminator. 626 auto *terminatorOp = entryBlock->getTerminator(); 627 Block *suspended = terminatorOp->getBlock(); 628 Block *resume = suspended->splitBlock(terminatorOp); 629 addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, 630 resume, builder); 631 632 size_t numDependencies = execute.dependencies().size(); 633 size_t numOperands = execute.operands().size(); 634 635 // Await on all dependencies before starting to execute the body region. 636 builder.setInsertionPointToStart(resume); 637 for (size_t i = 0; i < numDependencies; ++i) 638 builder.create<AwaitOp>(func.getArgument(i)); 639 640 // Await on all async value operands and unwrap the payload. 641 SmallVector<Value, 4> unwrappedOperands(numOperands); 642 for (size_t i = 0; i < numOperands; ++i) { 643 Value operand = func.getArgument(numDependencies + i); 644 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 645 } 646 647 // Map from function inputs defined above the execute op to the function 648 // arguments. 649 BlockAndValueMapping valueMapping; 650 valueMapping.map(functionInputs, func.getArguments()); 651 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 652 653 // Clone all operations from the execute operation body into the outlined 654 // function body. 655 for (Operation &op : execute.body().getOps()) 656 builder.clone(op, valueMapping); 657 658 // Replace the original `async.execute` with a call to outlined function. 659 ImplicitLocOpBuilder callBuilder(loc, execute); 660 auto callOutlinedFunc = callBuilder.create<CallOp>( 661 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 662 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 663 execute.erase(); 664 665 return {func, coro}; 666 } 667 668 //===----------------------------------------------------------------------===// 669 // Convert Async dialect types to LLVM types. 670 //===----------------------------------------------------------------------===// 671 672 namespace { 673 674 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 675 /// their runtime type (opaque pointers) and does not convert any other types. 676 class AsyncRuntimeTypeConverter : public TypeConverter { 677 public: 678 AsyncRuntimeTypeConverter() { 679 addConversion([](Type type) { return type; }); 680 addConversion(convertAsyncTypes); 681 } 682 683 static Optional<Type> convertAsyncTypes(Type type) { 684 if (type.isa<TokenType, GroupType, ValueType>()) 685 return AsyncAPI::opaquePointerType(type.getContext()); 686 return llvm::None; 687 } 688 }; 689 } // namespace 690 691 //===----------------------------------------------------------------------===// 692 // Convert return operations that return async values from async regions. 693 //===----------------------------------------------------------------------===// 694 695 namespace { 696 class ReturnOpOpConversion : public ConversionPattern { 697 public: 698 explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx) 699 : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {} 700 701 LogicalResult 702 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 703 ConversionPatternRewriter &rewriter) const override { 704 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 705 return success(); 706 } 707 }; 708 } // namespace 709 710 //===----------------------------------------------------------------------===// 711 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` 712 // to the corresponding API calls). 713 //===----------------------------------------------------------------------===// 714 715 namespace { 716 717 template <typename RefCountingOp> 718 class RefCountingOpLowering : public ConversionPattern { 719 public: 720 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 721 StringRef apiFunctionName) 722 : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx), 723 apiFunctionName(apiFunctionName) {} 724 725 LogicalResult 726 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 727 ConversionPatternRewriter &rewriter) const override { 728 RefCountingOp refCountingOp = cast<RefCountingOp>(op); 729 730 auto count = rewriter.create<ConstantOp>( 731 op->getLoc(), rewriter.getI32Type(), 732 rewriter.getI32IntegerAttr(refCountingOp.count())); 733 734 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 735 ValueRange({operands[0], count})); 736 737 return success(); 738 } 739 740 private: 741 StringRef apiFunctionName; 742 }; 743 744 /// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. 745 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> { 746 public: 747 explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 748 : RefCountingOpLowering(converter, ctx, kAddRef) {} 749 }; 750 751 /// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 752 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> { 753 public: 754 explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 755 : RefCountingOpLowering(converter, ctx, kDropRef) {} 756 }; 757 758 } // namespace 759 760 //===----------------------------------------------------------------------===// 761 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 762 //===----------------------------------------------------------------------===// 763 764 namespace { 765 class CreateGroupOpLowering : public ConversionPattern { 766 public: 767 explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) 768 : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter, 769 ctx) {} 770 771 LogicalResult 772 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 773 ConversionPatternRewriter &rewriter) const override { 774 auto retTy = GroupType::get(op->getContext()); 775 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy); 776 return success(); 777 } 778 }; 779 } // namespace 780 781 //===----------------------------------------------------------------------===// 782 // async.add_to_group op lowering to runtime function call. 783 //===----------------------------------------------------------------------===// 784 785 namespace { 786 class AddToGroupOpLowering : public ConversionPattern { 787 public: 788 explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) 789 : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) { 790 } 791 792 LogicalResult 793 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 794 ConversionPatternRewriter &rewriter) const override { 795 // Currently we can only add tokens to the group. 796 auto addToGroup = cast<AddToGroupOp>(op); 797 if (!addToGroup.operand().getType().isa<TokenType>()) 798 return failure(); 799 800 auto i64 = IntegerType::get(op->getContext(), 64); 801 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); 802 return success(); 803 } 804 }; 805 } // namespace 806 807 //===----------------------------------------------------------------------===// 808 // async.await and async.await_all op lowerings to the corresponding async 809 // runtime function calls. 810 //===----------------------------------------------------------------------===// 811 812 namespace { 813 814 template <typename AwaitType, typename AwaitableType> 815 class AwaitOpLoweringBase : public ConversionPattern { 816 protected: 817 explicit AwaitOpLoweringBase( 818 TypeConverter &converter, MLIRContext *ctx, 819 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions, 820 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) 821 : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx), 822 outlinedFunctions(outlinedFunctions), 823 blockingAwaitFuncName(blockingAwaitFuncName), 824 coroAwaitFuncName(coroAwaitFuncName) {} 825 826 public: 827 LogicalResult 828 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 829 ConversionPatternRewriter &rewriter) const override { 830 // We can only await on one the `AwaitableType` (for `await` it can be 831 // a `token` or a `value`, for `await_all` it must be a `group`). 832 auto await = cast<AwaitType>(op); 833 if (!await.operand().getType().template isa<AwaitableType>()) 834 return failure(); 835 836 // Check if await operation is inside the outlined coroutine function. 837 auto func = await->template getParentOfType<FuncOp>(); 838 auto outlined = outlinedFunctions.find(func); 839 const bool isInCoroutine = outlined != outlinedFunctions.end(); 840 841 Location loc = op->getLoc(); 842 843 // Inside regular function we convert await operation to the blocking 844 // async API await function call. 845 if (!isInCoroutine) 846 rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName, 847 ValueRange(operands[0])); 848 849 // Inside the coroutine we convert await operation into coroutine suspension 850 // point, and resume execution asynchronously. 851 if (isInCoroutine) { 852 const CoroMachinery &coro = outlined->getSecond(); 853 854 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 855 MLIRContext *ctx = op->getContext(); 856 857 // A pointer to coroutine resume intrinsic wrapper. 858 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 859 auto resumePtr = builder.create<LLVM::AddressOfOp>( 860 LLVM::LLVMPointerType::get(resumeFnTy), kResume); 861 862 // Save the coroutine state: @llvm.coro.save 863 auto coroSave = builder.create<LLVM::CallOp>( 864 LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 865 ValueRange(coro.coroHandle)); 866 867 // Call async runtime API to resume a coroutine in the managed thread when 868 // the async await argument becomes ready. 869 SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle, 870 resumePtr.res()}; 871 builder.create<CallOp>(TypeRange(), coroAwaitFuncName, 872 awaitAndExecuteArgs); 873 874 Block *suspended = op->getBlock(); 875 876 // Split the entry block before the await operation. 877 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 878 addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume, 879 builder); 880 881 // Make sure that replacement value will be constructed in resume block. 882 rewriter.setInsertionPointToStart(resume); 883 } 884 885 // Replace or erase the await operation with the new value. 886 if (Value replaceWith = getReplacementValue(op, operands[0], rewriter)) 887 rewriter.replaceOp(op, replaceWith); 888 else 889 rewriter.eraseOp(op); 890 891 return success(); 892 } 893 894 virtual Value getReplacementValue(Operation *op, Value operand, 895 ConversionPatternRewriter &rewriter) const { 896 return Value(); 897 } 898 899 private: 900 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 901 StringRef blockingAwaitFuncName; 902 StringRef coroAwaitFuncName; 903 }; 904 905 /// Lowering for `async.await` with a token operand. 906 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 907 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 908 909 public: 910 explicit AwaitTokenOpLowering( 911 TypeConverter &converter, MLIRContext *ctx, 912 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 913 : Base(converter, ctx, outlinedFunctions, kAwaitToken, 914 kAwaitTokenAndExecute) {} 915 }; 916 917 /// Lowering for `async.await` with a value operand. 918 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 919 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 920 921 public: 922 explicit AwaitValueOpLowering( 923 TypeConverter &converter, MLIRContext *ctx, 924 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 925 : Base(converter, ctx, outlinedFunctions, kAwaitValue, 926 kAwaitValueAndExecute) {} 927 928 Value 929 getReplacementValue(Operation *op, Value operand, 930 ConversionPatternRewriter &rewriter) const override { 931 Location loc = op->getLoc(); 932 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 933 934 // Get the underlying value type from the `async.value`. 935 auto await = cast<AwaitOp>(op); 936 auto valueType = await.operand().getType().cast<ValueType>().getValueType(); 937 938 // Get a pointer to an async value storage from the runtime. 939 auto storage = rewriter.create<CallOp>(loc, kGetValueStorage, 940 TypeRange(i8Ptr), operand); 941 942 // Cast from i8* to the pointer pointer to LLVM type. 943 auto llvmValueType = getTypeConverter()->convertType(valueType); 944 auto castedStorage = rewriter.create<LLVM::BitcastOp>( 945 loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0)); 946 947 // Load from the async value storage. 948 auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult()); 949 950 // Cast from LLVM type to the expected value type if necessary. This cast 951 // will become no-op after lowering to LLVM. 952 if (valueType == loaded.getType()) 953 return loaded; 954 return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded); 955 } 956 }; 957 958 /// Lowering for `async.await_all` operation. 959 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 960 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 961 962 public: 963 explicit AwaitAllOpLowering( 964 TypeConverter &converter, MLIRContext *ctx, 965 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 966 : Base(converter, ctx, outlinedFunctions, kAwaitGroup, 967 kAwaitAllAndExecute) {} 968 }; 969 970 } // namespace 971 972 //===----------------------------------------------------------------------===// 973 // async.yield op lowerings to the corresponding async runtime function calls. 974 //===----------------------------------------------------------------------===// 975 976 class YieldOpLowering : public ConversionPattern { 977 public: 978 explicit YieldOpLowering( 979 TypeConverter &converter, MLIRContext *ctx, 980 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 981 : ConversionPattern(async::YieldOp::getOperationName(), 1, converter, 982 ctx), 983 outlinedFunctions(outlinedFunctions) {} 984 985 LogicalResult 986 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 987 ConversionPatternRewriter &rewriter) const override { 988 // Check if yield operation is inside the outlined coroutine function. 989 auto func = op->template getParentOfType<FuncOp>(); 990 auto outlined = outlinedFunctions.find(func); 991 if (outlined == outlinedFunctions.end()) 992 return op->emitOpError( 993 "async.yield is not inside the outlined coroutine function"); 994 995 Location loc = op->getLoc(); 996 const CoroMachinery &coro = outlined->getSecond(); 997 998 // Store yielded values into the async values storage and emplace them. 999 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 1000 1001 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 1002 // Store `yieldValue` into the `asyncValue` storage. 1003 Value yieldValue = std::get<0>(tuple); 1004 Value asyncValue = std::get<1>(tuple); 1005 1006 // Get an opaque i8* pointer to an async value storage from the runtime. 1007 auto storage = rewriter.create<CallOp>(loc, kGetValueStorage, 1008 TypeRange(i8Ptr), asyncValue); 1009 1010 // Cast storage pointer to the yielded value type. 1011 auto castedStorage = rewriter.create<LLVM::BitcastOp>( 1012 loc, LLVM::LLVMPointerType::get(yieldValue.getType()), 1013 storage.getResult(0)); 1014 1015 // Store the yielded value into the async value storage. 1016 rewriter.create<LLVM::StoreOp>(loc, yieldValue, 1017 castedStorage.getResult()); 1018 1019 // Emplace the `async.value` to mark it ready. 1020 rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue); 1021 } 1022 1023 // Emplace the completion token to mark it ready. 1024 rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken); 1025 1026 // Original operation was replaced by the function call(s). 1027 rewriter.eraseOp(op); 1028 1029 return success(); 1030 } 1031 1032 private: 1033 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 1034 }; 1035 1036 //===----------------------------------------------------------------------===// 1037 1038 namespace { 1039 struct ConvertAsyncToLLVMPass 1040 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 1041 void runOnOperation() override; 1042 }; 1043 1044 void ConvertAsyncToLLVMPass::runOnOperation() { 1045 ModuleOp module = getOperation(); 1046 SymbolTable symbolTable(module); 1047 1048 MLIRContext *ctx = &getContext(); 1049 1050 // Outline all `async.execute` body regions into async functions (coroutines). 1051 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 1052 1053 // We use conversion to LLVM type to ensure that all `async.value` operands 1054 // and results can be lowered to LLVM load and store operations. 1055 LLVMTypeConverter llvmConverter(ctx); 1056 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 1057 1058 // Returns true if the `async.value` payload is convertible to LLVM. 1059 auto isConvertibleToLlvm = [&](Type type) -> bool { 1060 auto valueType = type.cast<ValueType>().getValueType(); 1061 return static_cast<bool>(llvmConverter.convertType(valueType)); 1062 }; 1063 1064 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 1065 // All operands and results must be convertible to LLVM. 1066 if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) { 1067 execute.emitOpError("operands payload must be convertible to LLVM type"); 1068 return WalkResult::interrupt(); 1069 } 1070 if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) { 1071 execute.emitOpError("results payload must be convertible to LLVM type"); 1072 return WalkResult::interrupt(); 1073 } 1074 1075 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 1076 1077 return WalkResult::advance(); 1078 }); 1079 1080 // Failed to outline all async execute operations. 1081 if (outlineResult.wasInterrupted()) { 1082 signalPassFailure(); 1083 return; 1084 } 1085 1086 LLVM_DEBUG({ 1087 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 1088 << " async functions\n"; 1089 }); 1090 1091 // Add declarations for all functions required by the coroutines lowering. 1092 addResumeFunction(module); 1093 addAsyncRuntimeApiDeclarations(module); 1094 addCoroutineIntrinsicsDeclarations(module); 1095 addCRuntimeDeclarations(module); 1096 1097 // Convert async dialect types and operations to LLVM dialect. 1098 AsyncRuntimeTypeConverter converter; 1099 OwningRewritePatternList patterns; 1100 1101 // Convert async types in function signatures and function calls. 1102 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 1103 populateCallOpTypeConversionPattern(patterns, ctx, converter); 1104 1105 // Convert return operations inside async.execute regions. 1106 patterns.insert<ReturnOpOpConversion>(converter, ctx); 1107 1108 // Lower async operations to async runtime API calls. 1109 patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx); 1110 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx); 1111 1112 // Use LLVM type converter to automatically convert between the async value 1113 // payload type and LLVM type when loading/storing from/to the async 1114 // value storage which is an opaque i8* pointer using LLVM load/store ops. 1115 patterns 1116 .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>( 1117 llvmConverter, ctx, outlinedFunctions); 1118 patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions); 1119 1120 ConversionTarget target(*ctx); 1121 target.addLegalOp<ConstantOp>(); 1122 target.addLegalDialect<LLVM::LLVMDialect>(); 1123 1124 // All operations from Async dialect must be lowered to the runtime API calls. 1125 target.addIllegalDialect<AsyncDialect>(); 1126 1127 // Add dynamic legality constraints to apply conversions defined above. 1128 target.addDynamicallyLegalOp<FuncOp>( 1129 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1130 target.addDynamicallyLegalOp<ReturnOp>( 1131 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1132 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1133 return converter.isSignatureLegal(op.getCalleeType()); 1134 }); 1135 1136 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1137 signalPassFailure(); 1138 } 1139 } // namespace 1140 1141 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1142 return std::make_unique<ConvertAsyncToLLVMPass>(); 1143 } 1144