1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Async/IR/Async.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "mlir/Transforms/RegionUtils.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/Support/FormatVariadic.h" 23 24 #define DEBUG_TYPE "convert-async-to-llvm" 25 26 using namespace mlir; 27 using namespace mlir::async; 28 29 // Prefix for functions outlined from `async.execute` op regions. 30 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 31 32 //===----------------------------------------------------------------------===// 33 // Async Runtime C API declaration. 34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 37 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 38 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 39 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 40 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 41 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 42 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 43 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 44 static constexpr const char *kAddTokenToGroup = 45 "mlirAsyncRuntimeAddTokenToGroup"; 46 static constexpr const char *kAwaitAndExecute = 47 "mlirAsyncRuntimeAwaitTokenAndExecute"; 48 static constexpr const char *kAwaitAllAndExecute = 49 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 50 51 namespace { 52 // Async Runtime API function types. 53 struct AsyncAPI { 54 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 55 auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); 56 auto count = IntegerType::get(32, ctx); 57 return FunctionType::get({ref, count}, {}, ctx); 58 } 59 60 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 61 return FunctionType::get({}, {TokenType::get(ctx)}, ctx); 62 } 63 64 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 65 return FunctionType::get({}, {GroupType::get(ctx)}, ctx); 66 } 67 68 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 69 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 70 } 71 72 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 73 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 74 } 75 76 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 77 return FunctionType::get({GroupType::get(ctx)}, {}, ctx); 78 } 79 80 static FunctionType executeFunctionType(MLIRContext *ctx) { 81 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 82 auto resume = resumeFunctionType(ctx).getPointerTo(); 83 return FunctionType::get({hdl, resume}, {}, ctx); 84 } 85 86 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 87 auto i64 = IntegerType::get(64, ctx); 88 return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64}, 89 ctx); 90 } 91 92 static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { 93 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 94 auto resume = resumeFunctionType(ctx).getPointerTo(); 95 return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); 96 } 97 98 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 99 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 100 auto resume = resumeFunctionType(ctx).getPointerTo(); 101 return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx); 102 } 103 104 // Auxiliary coroutine resume intrinsic wrapper. 105 static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { 106 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 107 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 108 return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); 109 } 110 }; 111 } // namespace 112 113 // Adds Async Runtime C API declarations to the module. 114 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 115 auto builder = OpBuilder::atBlockTerminator(module.getBody()); 116 117 auto addFuncDecl = [&](StringRef name, FunctionType type) { 118 if (module.lookupSymbol(name)) 119 return; 120 builder.create<FuncOp>(module.getLoc(), name, type).setPrivate(); 121 }; 122 123 MLIRContext *ctx = module.getContext(); 124 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 125 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 126 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 127 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 128 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 129 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 130 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 131 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 132 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 133 addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx)); 134 addFuncDecl(kAwaitAllAndExecute, 135 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // LLVM coroutines intrinsics declarations. 140 //===----------------------------------------------------------------------===// 141 142 static constexpr const char *kCoroId = "llvm.coro.id"; 143 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 144 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 145 static constexpr const char *kCoroSave = "llvm.coro.save"; 146 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 147 static constexpr const char *kCoroEnd = "llvm.coro.end"; 148 static constexpr const char *kCoroFree = "llvm.coro.free"; 149 static constexpr const char *kCoroResume = "llvm.coro.resume"; 150 151 /// Adds an LLVM function declaration to a module. 152 static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name, 153 LLVM::LLVMType ret, 154 ArrayRef<LLVM::LLVMType> params) { 155 if (module.lookupSymbol(name)) 156 return; 157 LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false); 158 builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type); 159 } 160 161 /// Adds coroutine intrinsics declarations to the module. 162 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 163 using namespace mlir::LLVM; 164 165 MLIRContext *ctx = module.getContext(); 166 OpBuilder builder(module.getBody()->getTerminator()); 167 168 auto token = LLVMTokenType::get(ctx); 169 auto voidTy = LLVMType::getVoidTy(ctx); 170 171 auto i8 = LLVMType::getInt8Ty(ctx); 172 auto i1 = LLVMType::getInt1Ty(ctx); 173 auto i32 = LLVMType::getInt32Ty(ctx); 174 auto i64 = LLVMType::getInt64Ty(ctx); 175 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 176 177 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); 178 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); 179 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); 180 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); 181 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); 182 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); 183 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); 184 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); 185 } 186 187 //===----------------------------------------------------------------------===// 188 // Add malloc/free declarations to the module. 189 //===----------------------------------------------------------------------===// 190 191 static constexpr const char *kMalloc = "malloc"; 192 static constexpr const char *kFree = "free"; 193 194 /// Adds malloc/free declarations to the module. 195 static void addCRuntimeDeclarations(ModuleOp module) { 196 using namespace mlir::LLVM; 197 198 MLIRContext *ctx = module.getContext(); 199 OpBuilder builder(module.getBody()->getTerminator()); 200 201 auto voidTy = LLVMType::getVoidTy(ctx); 202 auto i64 = LLVMType::getInt64Ty(ctx); 203 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 204 205 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 206 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 207 } 208 209 //===----------------------------------------------------------------------===// 210 // Coroutine resume function wrapper. 211 //===----------------------------------------------------------------------===// 212 213 static constexpr const char *kResume = "__resume"; 214 215 // A function that takes a coroutine handle and calls a `llvm.coro.resume` 216 // intrinsics. We need this function to be able to pass it to the async 217 // runtime execute API. 218 static void addResumeFunction(ModuleOp module) { 219 MLIRContext *ctx = module.getContext(); 220 221 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 222 Location loc = module.getLoc(); 223 224 if (module.lookupSymbol(kResume)) 225 return; 226 227 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 228 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 229 230 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 231 loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false)); 232 resumeOp.setPrivate(); 233 234 auto *block = resumeOp.addEntryBlock(); 235 OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); 236 237 blockBuilder.create<LLVM::CallOp>(loc, TypeRange(), 238 blockBuilder.getSymbolRefAttr(kCoroResume), 239 resumeOp.getArgument(0)); 240 241 blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange()); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // async.execute op outlining to the coroutine functions. 246 //===----------------------------------------------------------------------===// 247 248 // Function targeted for coroutine transformation has two additional blocks at 249 // the end: coroutine cleanup and coroutine suspension. 250 // 251 // async.await op lowering additionaly creates a resume block for each 252 // operation to enable non-blocking waiting via coroutine suspension. 253 namespace { 254 struct CoroMachinery { 255 Value asyncToken; 256 Value coroHandle; 257 Block *cleanup; 258 Block *suspend; 259 }; 260 } // namespace 261 262 // Builds an coroutine template compatible with LLVM coroutines lowering. 263 // 264 // - `entry` block sets up the coroutine. 265 // - `cleanup` block cleans up the coroutine state. 266 // - `suspend block after the @llvm.coro.end() defines what value will be 267 // returned to the initial caller of a coroutine. Everything before the 268 // @llvm.coro.end() will be executed at every suspension point. 269 // 270 // Coroutine structure (only the important bits): 271 // 272 // func @async_execute_fn(<function-arguments>) -> !async.token { 273 // ^entryBlock(<function-arguments>): 274 // %token = <async token> : !async.token // create async runtime token 275 // %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle 276 // br ^cleanup 277 // 278 // ^cleanup: 279 // llvm.call @llvm.coro.free(...) // delete coroutine state 280 // br ^suspend 281 // 282 // ^suspend: 283 // llvm.call @llvm.coro.end(...) // marks the end of a coroutine 284 // return %token : !async.token 285 // } 286 // 287 // The actual code for the async.execute operation body region will be inserted 288 // before the entry block terminator. 289 // 290 // 291 static CoroMachinery setupCoroMachinery(FuncOp func) { 292 assert(func.getBody().empty() && "Function must have empty body"); 293 294 MLIRContext *ctx = func.getContext(); 295 296 auto token = LLVM::LLVMTokenType::get(ctx); 297 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 298 auto i32 = LLVM::LLVMType::getInt32Ty(ctx); 299 auto i64 = LLVM::LLVMType::getInt64Ty(ctx); 300 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 301 302 Block *entryBlock = func.addEntryBlock(); 303 Location loc = func.getBody().getLoc(); 304 305 OpBuilder builder = OpBuilder::atBlockBegin(entryBlock); 306 307 // ------------------------------------------------------------------------ // 308 // Allocate async tokens/values that we will return from a ramp function. 309 // ------------------------------------------------------------------------ // 310 auto createToken = 311 builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx)); 312 313 // ------------------------------------------------------------------------ // 314 // Initialize coroutine: allocate frame, get coroutine handle. 315 // ------------------------------------------------------------------------ // 316 317 // Constants for initializing coroutine frame. 318 auto constZero = 319 builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0)); 320 auto constFalse = 321 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 322 auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr); 323 324 // Get coroutine id: @llvm.coro.id 325 auto coroId = builder.create<LLVM::CallOp>( 326 loc, token, builder.getSymbolRefAttr(kCoroId), 327 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 328 329 // Get coroutine frame size: @llvm.coro.size.i64 330 auto coroSize = builder.create<LLVM::CallOp>( 331 loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 332 333 // Allocate memory for coroutine frame. 334 auto coroAlloc = builder.create<LLVM::CallOp>( 335 loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), 336 ValueRange(coroSize.getResult(0))); 337 338 // Begin a coroutine: @llvm.coro.begin 339 auto coroHdl = builder.create<LLVM::CallOp>( 340 loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 341 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 342 343 Block *cleanupBlock = func.addBlock(); 344 Block *suspendBlock = func.addBlock(); 345 346 // ------------------------------------------------------------------------ // 347 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 348 // ------------------------------------------------------------------------ // 349 builder.setInsertionPointToStart(cleanupBlock); 350 351 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 352 auto coroMem = builder.create<LLVM::CallOp>( 353 loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), 354 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 355 356 // Free the memory. 357 builder.create<LLVM::CallOp>(loc, TypeRange(), 358 builder.getSymbolRefAttr(kFree), 359 ValueRange(coroMem.getResult(0))); 360 // Branch into the suspend block. 361 builder.create<BranchOp>(loc, suspendBlock); 362 363 // ------------------------------------------------------------------------ // 364 // Coroutine suspend block: mark the end of a coroutine and return allocated 365 // async token. 366 // ------------------------------------------------------------------------ // 367 builder.setInsertionPointToStart(suspendBlock); 368 369 // Mark the end of a coroutine: @llvm.coro.end. 370 builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd), 371 ValueRange({coroHdl.getResult(0), constFalse})); 372 373 // Return created `async.token` from the suspend block. This will be the 374 // return value of a coroutine ramp function. 375 builder.create<ReturnOp>(loc, createToken.getResult(0)); 376 377 // Branch from the entry block to the cleanup block to create a valid CFG. 378 builder.setInsertionPointToEnd(entryBlock); 379 380 builder.create<BranchOp>(loc, cleanupBlock); 381 382 // `async.await` op lowering will create resume blocks for async 383 // continuations, and will conditionally branch to cleanup or suspend blocks. 384 385 return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock, 386 suspendBlock}; 387 } 388 389 // Add a LLVM coroutine suspension point to the end of suspended block, to 390 // resume execution in resume block. The caller is responsible for creating the 391 // two suspended/resume blocks with the desired ops contained in each block. 392 // This function merely provides the required control flow logic. 393 // 394 // `coroState` must be a value returned from the call to @llvm.coro.save(...) 395 // intrinsic (saved coroutine state). 396 // 397 // Before: 398 // 399 // ^bb0: 400 // "opBefore"(...) 401 // "op"(...) 402 // ^cleanup: ... 403 // ^suspend: ... 404 // ^resume: 405 // "op"(...) 406 // 407 // After: 408 // 409 // ^bb0: 410 // "opBefore"(...) 411 // %suspend = llmv.call @llvm.coro.suspend(...) 412 // switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 413 // ^resume: 414 // "op"(...) 415 // ^cleanup: ... 416 // ^suspend: ... 417 // 418 static void addSuspensionPoint(CoroMachinery coro, Value coroState, 419 Operation *op, Block *suspended, Block *resume, 420 OpBuilder &builder) { 421 Location loc = op->getLoc(); 422 MLIRContext *ctx = op->getContext(); 423 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 424 auto i8 = LLVM::LLVMType::getInt8Ty(ctx); 425 426 // Add a coroutine suspension in place of original `op` in the split block. 427 OpBuilder::InsertionGuard guard(builder); 428 builder.setInsertionPointToEnd(suspended); 429 430 auto constFalse = 431 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 432 433 // Suspend a coroutine: @llvm.coro.suspend 434 auto coroSuspend = builder.create<LLVM::CallOp>( 435 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 436 ValueRange({coroState, constFalse})); 437 438 // After a suspension point decide if we should branch into resume, cleanup 439 // or suspend block of the coroutine (see @llvm.coro.suspend return code 440 // documentation). 441 auto constZero = 442 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 443 auto constNegOne = 444 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 445 446 Block *resumeOrCleanup = builder.createBlock(resume); 447 448 // Suspend the coroutine ...? 449 builder.setInsertionPointToEnd(suspended); 450 auto isNegOne = builder.create<LLVM::ICmpOp>( 451 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 452 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 453 /*falseDest=*/resumeOrCleanup); 454 455 // ... or resume or cleanup the coroutine? 456 builder.setInsertionPointToStart(resumeOrCleanup); 457 auto isZero = builder.create<LLVM::ICmpOp>( 458 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 459 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 460 /*falseDest=*/coro.cleanup); 461 } 462 463 // Outline the body region attached to the `async.execute` op into a standalone 464 // function. 465 // 466 // Note that this is not reversible transformation. 467 static std::pair<FuncOp, CoroMachinery> 468 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 469 ModuleOp module = execute->getParentOfType<ModuleOp>(); 470 471 MLIRContext *ctx = module.getContext(); 472 Location loc = execute.getLoc(); 473 474 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 475 476 // Collect all outlined function inputs. 477 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 478 execute.dependencies().end()); 479 getUsedValuesDefinedAbove(execute.body(), functionInputs); 480 481 // Collect types for the outlined function inputs and outputs. 482 auto typesRange = llvm::map_range( 483 functionInputs, [](Value value) { return value.getType(); }); 484 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 485 auto outputTypes = execute.getResultTypes(); 486 487 auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); 488 auto funcAttrs = ArrayRef<NamedAttribute>(); 489 490 // TODO: Derive outlined function name from the parent FuncOp (support 491 // multiple nested async.execute operations). 492 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 493 symbolTable.insert(func, moduleBuilder.getInsertionPoint()); 494 495 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 496 497 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 498 // blocks, adding llvm.coro instrinsics and setting up control flow. 499 CoroMachinery coro = setupCoroMachinery(func); 500 501 // Suspend async function at the end of an entry block, and resume it using 502 // Async execute API (execution will be resumed in a thread managed by the 503 // async runtime). 504 Block *entryBlock = &func.getBlocks().front(); 505 OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock); 506 507 // A pointer to coroutine resume intrinsic wrapper. 508 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 509 auto resumePtr = builder.create<LLVM::AddressOfOp>( 510 loc, resumeFnTy.getPointerTo(), kResume); 511 512 // Save the coroutine state: @llvm.coro.save 513 auto coroSave = builder.create<LLVM::CallOp>( 514 loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 515 ValueRange({coro.coroHandle})); 516 517 // Call async runtime API to execute a coroutine in the managed thread. 518 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 519 builder.create<CallOp>(loc, TypeRange(), kExecute, executeArgs); 520 521 // Split the entry block before the terminator. 522 auto *terminatorOp = entryBlock->getTerminator(); 523 Block *suspended = terminatorOp->getBlock(); 524 Block *resume = suspended->splitBlock(terminatorOp); 525 addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, 526 resume, builder); 527 528 // Await on all dependencies before starting to execute the body region. 529 builder.setInsertionPointToStart(resume); 530 for (size_t i = 0; i < execute.dependencies().size(); ++i) 531 builder.create<AwaitOp>(loc, func.getArgument(i)); 532 533 // Map from function inputs defined above the execute op to the function 534 // arguments. 535 BlockAndValueMapping valueMapping; 536 valueMapping.map(functionInputs, func.getArguments()); 537 538 // Clone all operations from the execute operation body into the outlined 539 // function body, and replace all `async.yield` operations with a call 540 // to async runtime to emplace the result token. 541 for (Operation &op : execute.body().getOps()) { 542 if (isa<async::YieldOp>(op)) { 543 builder.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken); 544 continue; 545 } 546 builder.clone(op, valueMapping); 547 } 548 549 // Replace the original `async.execute` with a call to outlined function. 550 OpBuilder callBuilder(execute); 551 auto callOutlinedFunc = 552 callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(), 553 functionInputs.getArrayRef()); 554 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 555 execute.erase(); 556 557 return {func, coro}; 558 } 559 560 //===----------------------------------------------------------------------===// 561 // Convert Async dialect types to LLVM types. 562 //===----------------------------------------------------------------------===// 563 564 namespace { 565 class AsyncRuntimeTypeConverter : public TypeConverter { 566 public: 567 AsyncRuntimeTypeConverter() { addConversion(convertType); } 568 569 static Type convertType(Type type) { 570 MLIRContext *ctx = type.getContext(); 571 // Convert async tokens and groups to opaque pointers. 572 if (type.isa<TokenType, GroupType>()) 573 return LLVM::LLVMType::getInt8PtrTy(ctx); 574 return type; 575 } 576 }; 577 } // namespace 578 579 //===----------------------------------------------------------------------===// 580 // Convert types for all call operations to lowered async types. 581 //===----------------------------------------------------------------------===// 582 583 namespace { 584 class CallOpOpConversion : public ConversionPattern { 585 public: 586 explicit CallOpOpConversion(MLIRContext *ctx) 587 : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} 588 589 LogicalResult 590 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 591 ConversionPatternRewriter &rewriter) const override { 592 AsyncRuntimeTypeConverter converter; 593 594 SmallVector<Type, 5> resultTypes; 595 converter.convertTypes(op->getResultTypes(), resultTypes); 596 597 CallOp call = cast<CallOp>(op); 598 rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(), 599 operands); 600 601 return success(); 602 } 603 }; 604 } // namespace 605 606 //===----------------------------------------------------------------------===// 607 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` 608 // to the corresponding API calls). 609 //===----------------------------------------------------------------------===// 610 611 namespace { 612 613 template <typename RefCountingOp> 614 class RefCountingOpLowering : public ConversionPattern { 615 public: 616 explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName) 617 : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx), 618 apiFunctionName(apiFunctionName) {} 619 620 LogicalResult 621 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 622 ConversionPatternRewriter &rewriter) const override { 623 RefCountingOp refCountingOp = cast<RefCountingOp>(op); 624 625 auto count = rewriter.create<ConstantOp>( 626 op->getLoc(), rewriter.getI32Type(), 627 rewriter.getI32IntegerAttr(refCountingOp.count())); 628 629 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 630 ValueRange({operands[0], count})); 631 632 return success(); 633 } 634 635 private: 636 StringRef apiFunctionName; 637 }; 638 639 // async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. 640 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> { 641 public: 642 explicit AddRefOpLowering(MLIRContext *ctx) 643 : RefCountingOpLowering(ctx, kAddRef) {} 644 }; 645 646 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 647 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> { 648 public: 649 explicit DropRefOpLowering(MLIRContext *ctx) 650 : RefCountingOpLowering(ctx, kDropRef) {} 651 }; 652 653 } // namespace 654 655 //===----------------------------------------------------------------------===// 656 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 657 //===----------------------------------------------------------------------===// 658 659 namespace { 660 class CreateGroupOpLowering : public ConversionPattern { 661 public: 662 explicit CreateGroupOpLowering(MLIRContext *ctx) 663 : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {} 664 665 LogicalResult 666 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 667 ConversionPatternRewriter &rewriter) const override { 668 auto retTy = GroupType::get(op->getContext()); 669 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy); 670 return success(); 671 } 672 }; 673 } // namespace 674 675 //===----------------------------------------------------------------------===// 676 // async.add_to_group op lowering to runtime function call. 677 //===----------------------------------------------------------------------===// 678 679 namespace { 680 class AddToGroupOpLowering : public ConversionPattern { 681 public: 682 explicit AddToGroupOpLowering(MLIRContext *ctx) 683 : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {} 684 685 LogicalResult 686 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 687 ConversionPatternRewriter &rewriter) const override { 688 // Currently we can only add tokens to the group. 689 auto addToGroup = cast<AddToGroupOp>(op); 690 if (!addToGroup.operand().getType().isa<TokenType>()) 691 return failure(); 692 693 auto i64 = IntegerType::get(64, op->getContext()); 694 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); 695 return success(); 696 } 697 }; 698 } // namespace 699 700 //===----------------------------------------------------------------------===// 701 // async.await and async.await_all op lowerings to the corresponding async 702 // runtime function calls. 703 //===----------------------------------------------------------------------===// 704 705 namespace { 706 707 template <typename AwaitType, typename AwaitableType> 708 class AwaitOpLoweringBase : public ConversionPattern { 709 protected: 710 explicit AwaitOpLoweringBase( 711 MLIRContext *ctx, 712 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions, 713 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) 714 : ConversionPattern(AwaitType::getOperationName(), 1, ctx), 715 outlinedFunctions(outlinedFunctions), 716 blockingAwaitFuncName(blockingAwaitFuncName), 717 coroAwaitFuncName(coroAwaitFuncName) {} 718 719 public: 720 LogicalResult 721 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 722 ConversionPatternRewriter &rewriter) const override { 723 // We can only await on one the `AwaitableType` (for `await` it can be 724 // only a `token`, for `await_all` it is a `group`). 725 auto await = cast<AwaitType>(op); 726 if (!await.operand().getType().template isa<AwaitableType>()) 727 return failure(); 728 729 // Check if await operation is inside the outlined coroutine function. 730 auto func = await->template getParentOfType<FuncOp>(); 731 auto outlined = outlinedFunctions.find(func); 732 const bool isInCoroutine = outlined != outlinedFunctions.end(); 733 734 Location loc = op->getLoc(); 735 736 // Inside regular function we convert await operation to the blocking 737 // async API await function call. 738 if (!isInCoroutine) 739 rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName, 740 ValueRange(operands[0])); 741 742 // Inside the coroutine we convert await operation into coroutine suspension 743 // point, and resume execution asynchronously. 744 if (isInCoroutine) { 745 const CoroMachinery &coro = outlined->getSecond(); 746 747 OpBuilder builder(op, rewriter.getListener()); 748 MLIRContext *ctx = op->getContext(); 749 750 // A pointer to coroutine resume intrinsic wrapper. 751 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 752 auto resumePtr = builder.create<LLVM::AddressOfOp>( 753 loc, resumeFnTy.getPointerTo(), kResume); 754 755 // Save the coroutine state: @llvm.coro.save 756 auto coroSave = builder.create<LLVM::CallOp>( 757 loc, LLVM::LLVMTokenType::get(ctx), 758 builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); 759 760 // Call async runtime API to resume a coroutine in the managed thread when 761 // the async await argument becomes ready. 762 SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle, 763 resumePtr.res()}; 764 builder.create<CallOp>(loc, TypeRange(), coroAwaitFuncName, 765 awaitAndExecuteArgs); 766 767 Block *suspended = op->getBlock(); 768 769 // Split the entry block before the await operation. 770 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 771 addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume, 772 builder); 773 } 774 775 // Original operation was replaced by function call or suspension point. 776 rewriter.eraseOp(op); 777 778 return success(); 779 } 780 781 private: 782 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 783 StringRef blockingAwaitFuncName; 784 StringRef coroAwaitFuncName; 785 }; 786 787 // Lowering for `async.await` operation (only token operands are supported). 788 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 789 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 790 791 public: 792 explicit AwaitOpLowering( 793 MLIRContext *ctx, 794 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 795 : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {} 796 }; 797 798 // Lowering for `async.await_all` operation. 799 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 800 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 801 802 public: 803 explicit AwaitAllOpLowering( 804 MLIRContext *ctx, 805 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 806 : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {} 807 }; 808 809 } // namespace 810 811 //===----------------------------------------------------------------------===// 812 813 namespace { 814 struct ConvertAsyncToLLVMPass 815 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 816 void runOnOperation() override; 817 }; 818 819 void ConvertAsyncToLLVMPass::runOnOperation() { 820 ModuleOp module = getOperation(); 821 SymbolTable symbolTable(module); 822 823 // Outline all `async.execute` body regions into async functions (coroutines). 824 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 825 826 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 827 // We currently do not support execute operations that have async value 828 // operands or produce async results. 829 if (!execute.operands().empty() || !execute.results().empty()) { 830 execute.emitOpError("can't outline async.execute op with async value " 831 "operands or returned async results"); 832 return WalkResult::interrupt(); 833 } 834 835 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 836 837 return WalkResult::advance(); 838 }); 839 840 // Failed to outline all async execute operations. 841 if (outlineResult.wasInterrupted()) { 842 signalPassFailure(); 843 return; 844 } 845 846 LLVM_DEBUG({ 847 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 848 << " async functions\n"; 849 }); 850 851 // Add declarations for all functions required by the coroutines lowering. 852 addResumeFunction(module); 853 addAsyncRuntimeApiDeclarations(module); 854 addCoroutineIntrinsicsDeclarations(module); 855 addCRuntimeDeclarations(module); 856 857 MLIRContext *ctx = &getContext(); 858 859 // Convert async dialect types and operations to LLVM dialect. 860 AsyncRuntimeTypeConverter converter; 861 OwningRewritePatternList patterns; 862 863 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 864 patterns.insert<CallOpOpConversion>(ctx); 865 patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx); 866 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 867 patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions); 868 869 ConversionTarget target(*ctx); 870 target.addLegalOp<ConstantOp>(); 871 target.addLegalDialect<LLVM::LLVMDialect>(); 872 target.addIllegalDialect<AsyncDialect>(); 873 target.addDynamicallyLegalOp<FuncOp>( 874 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 875 target.addDynamicallyLegalOp<CallOp>( 876 [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); 877 878 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 879 signalPassFailure(); 880 } 881 } // namespace 882 883 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 884 return std::make_unique<ConvertAsyncToLLVMPass>(); 885 } 886