//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"

// Debug category for the `LLVM_DEBUG` output emitted by this file.
#define DEBUG_TYPE "convert-async-to-llvm"

using namespace mlir;
using namespace mlir::async;

// Prefix for functions outlined from `async.execute` op regions. It is used as
// the requested symbol name when the outlined function is created.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 37 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 38 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 39 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 40 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 41 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 42 static constexpr const char *kAddTokenToGroup = 43 "mlirAsyncRuntimeAddTokenToGroup"; 44 static constexpr const char *kAwaitAndExecute = 45 "mlirAsyncRuntimeAwaitTokenAndExecute"; 46 static constexpr const char *kAwaitAllAndExecute = 47 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 48 49 namespace { 50 // Async Runtime API function types. 51 struct AsyncAPI { 52 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 53 return FunctionType::get({}, {TokenType::get(ctx)}, ctx); 54 } 55 56 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 57 return FunctionType::get({}, {GroupType::get(ctx)}, ctx); 58 } 59 60 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 61 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 62 } 63 64 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 65 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 66 } 67 68 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 69 return FunctionType::get({GroupType::get(ctx)}, {}, ctx); 70 } 71 72 static FunctionType executeFunctionType(MLIRContext *ctx) { 73 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 74 auto resume = resumeFunctionType(ctx).getPointerTo(); 75 return FunctionType::get({hdl, resume}, {}, ctx); 76 } 77 78 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 79 auto i64 = IntegerType::get(64, ctx); 80 return FunctionType::get({TokenType::get(ctx), 
GroupType::get(ctx)}, {i64}, 81 ctx); 82 } 83 84 static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { 85 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 86 auto resume = resumeFunctionType(ctx).getPointerTo(); 87 return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); 88 } 89 90 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 91 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 92 auto resume = resumeFunctionType(ctx).getPointerTo(); 93 return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx); 94 } 95 96 // Auxiliary coroutine resume intrinsic wrapper. 97 static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { 98 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 99 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 100 return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); 101 } 102 }; 103 } // namespace 104 105 // Adds Async Runtime C API declarations to the module. 106 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 107 auto builder = OpBuilder::atBlockTerminator(module.getBody()); 108 109 MLIRContext *ctx = module.getContext(); 110 Location loc = module.getLoc(); 111 112 if (!module.lookupSymbol(kCreateToken)) 113 builder.create<FuncOp>(loc, kCreateToken, 114 AsyncAPI::createTokenFunctionType(ctx)); 115 116 if (!module.lookupSymbol(kCreateGroup)) 117 builder.create<FuncOp>(loc, kCreateGroup, 118 AsyncAPI::createGroupFunctionType(ctx)); 119 120 if (!module.lookupSymbol(kEmplaceToken)) 121 builder.create<FuncOp>(loc, kEmplaceToken, 122 AsyncAPI::emplaceTokenFunctionType(ctx)); 123 124 if (!module.lookupSymbol(kAwaitToken)) 125 builder.create<FuncOp>(loc, kAwaitToken, 126 AsyncAPI::awaitTokenFunctionType(ctx)); 127 128 if (!module.lookupSymbol(kAwaitGroup)) 129 builder.create<FuncOp>(loc, kAwaitGroup, 130 AsyncAPI::awaitGroupFunctionType(ctx)); 131 132 if (!module.lookupSymbol(kExecute)) 133 builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx)); 134 135 if 
(!module.lookupSymbol(kAddTokenToGroup)) 136 builder.create<FuncOp>(loc, kAddTokenToGroup, 137 AsyncAPI::addTokenToGroupFunctionType(ctx)); 138 139 if (!module.lookupSymbol(kAwaitAndExecute)) 140 builder.create<FuncOp>(loc, kAwaitAndExecute, 141 AsyncAPI::awaitAndExecuteFunctionType(ctx)); 142 143 if (!module.lookupSymbol(kAwaitAllAndExecute)) 144 builder.create<FuncOp>(loc, kAwaitAllAndExecute, 145 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 146 } 147 148 //===----------------------------------------------------------------------===// 149 // LLVM coroutines intrinsics declarations. 150 //===----------------------------------------------------------------------===// 151 152 static constexpr const char *kCoroId = "llvm.coro.id"; 153 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 154 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 155 static constexpr const char *kCoroSave = "llvm.coro.save"; 156 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 157 static constexpr const char *kCoroEnd = "llvm.coro.end"; 158 static constexpr const char *kCoroFree = "llvm.coro.free"; 159 static constexpr const char *kCoroResume = "llvm.coro.resume"; 160 161 /// Adds coroutine intrinsics declarations to the module. 
162 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 163 using namespace mlir::LLVM; 164 165 MLIRContext *ctx = module.getContext(); 166 Location loc = module.getLoc(); 167 168 OpBuilder builder(module.getBody()->getTerminator()); 169 170 auto token = LLVMTokenType::get(ctx); 171 auto voidTy = LLVMType::getVoidTy(ctx); 172 173 auto i8 = LLVMType::getInt8Ty(ctx); 174 auto i1 = LLVMType::getInt1Ty(ctx); 175 auto i32 = LLVMType::getInt32Ty(ctx); 176 auto i64 = LLVMType::getInt64Ty(ctx); 177 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 178 179 if (!module.lookupSymbol(kCoroId)) 180 builder.create<LLVMFuncOp>( 181 loc, kCoroId, 182 LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false)); 183 184 if (!module.lookupSymbol(kCoroSizeI64)) 185 builder.create<LLVMFuncOp>(loc, kCoroSizeI64, 186 LLVMType::getFunctionTy(i64, false)); 187 188 if (!module.lookupSymbol(kCoroBegin)) 189 builder.create<LLVMFuncOp>( 190 loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); 191 192 if (!module.lookupSymbol(kCoroSave)) 193 builder.create<LLVMFuncOp>(loc, kCoroSave, 194 LLVMType::getFunctionTy(token, i8Ptr, false)); 195 196 if (!module.lookupSymbol(kCoroSuspend)) 197 builder.create<LLVMFuncOp>(loc, kCoroSuspend, 198 LLVMType::getFunctionTy(i8, {token, i1}, false)); 199 200 if (!module.lookupSymbol(kCoroEnd)) 201 builder.create<LLVMFuncOp>(loc, kCoroEnd, 202 LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false)); 203 204 if (!module.lookupSymbol(kCoroFree)) 205 builder.create<LLVMFuncOp>( 206 loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); 207 208 if (!module.lookupSymbol(kCoroResume)) 209 builder.create<LLVMFuncOp>(loc, kCoroResume, 210 LLVMType::getFunctionTy(voidTy, i8Ptr, false)); 211 } 212 213 //===----------------------------------------------------------------------===// 214 // Add malloc/free declarations to the module. 
215 //===----------------------------------------------------------------------===// 216 217 static constexpr const char *kMalloc = "malloc"; 218 static constexpr const char *kFree = "free"; 219 220 /// Adds malloc/free declarations to the module. 221 static void addCRuntimeDeclarations(ModuleOp module) { 222 using namespace mlir::LLVM; 223 224 MLIRContext *ctx = module.getContext(); 225 Location loc = module.getLoc(); 226 227 OpBuilder builder(module.getBody()->getTerminator()); 228 229 auto voidTy = LLVMType::getVoidTy(ctx); 230 auto i64 = LLVMType::getInt64Ty(ctx); 231 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 232 233 if (!module.lookupSymbol(kMalloc)) 234 builder.create<LLVM::LLVMFuncOp>( 235 loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false)); 236 237 if (!module.lookupSymbol(kFree)) 238 builder.create<LLVM::LLVMFuncOp>( 239 loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false)); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // Coroutine resume function wrapper. 244 //===----------------------------------------------------------------------===// 245 246 static constexpr const char *kResume = "__resume"; 247 248 // A function that takes a coroutine handle and calls a `llvm.coro.resume` 249 // intrinsics. We need this function to be able to pass it to the async 250 // runtime execute API. 
251 static void addResumeFunction(ModuleOp module) { 252 MLIRContext *ctx = module.getContext(); 253 254 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 255 Location loc = module.getLoc(); 256 257 if (module.lookupSymbol(kResume)) 258 return; 259 260 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 261 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 262 263 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 264 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 265 SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private); 266 267 auto *block = resumeOp.addEntryBlock(); 268 OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); 269 270 blockBuilder.create<LLVM::CallOp>(loc, Type(), 271 blockBuilder.getSymbolRefAttr(kCoroResume), 272 resumeOp.getArgument(0)); 273 274 blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange()); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // async.execute op outlining to the coroutine functions. 279 //===----------------------------------------------------------------------===// 280 281 // Function targeted for coroutine transformation has two additional blocks at 282 // the end: coroutine cleanup and coroutine suspension. 283 // 284 // async.await op lowering additionaly creates a resume block for each 285 // operation to enable non-blocking waiting via coroutine suspension. 286 namespace { 287 struct CoroMachinery { 288 Value asyncToken; 289 Value coroHandle; 290 Block *cleanup; 291 Block *suspend; 292 }; 293 } // namespace 294 295 // Builds an coroutine template compatible with LLVM coroutines lowering. 296 // 297 // - `entry` block sets up the coroutine. 298 // - `cleanup` block cleans up the coroutine state. 299 // - `suspend block after the @llvm.coro.end() defines what value will be 300 // returned to the initial caller of a coroutine. Everything before the 301 // @llvm.coro.end() will be executed at every suspension point. 
302 // 303 // Coroutine structure (only the important bits): 304 // 305 // func @async_execute_fn(<function-arguments>) -> !async.token { 306 // ^entryBlock(<function-arguments>): 307 // %token = <async token> : !async.token // create async runtime token 308 // %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle 309 // br ^cleanup 310 // 311 // ^cleanup: 312 // llvm.call @llvm.coro.free(...) // delete coroutine state 313 // br ^suspend 314 // 315 // ^suspend: 316 // llvm.call @llvm.coro.end(...) // marks the end of a coroutine 317 // return %token : !async.token 318 // } 319 // 320 // The actual code for the async.execute operation body region will be inserted 321 // before the entry block terminator. 322 // 323 // 324 static CoroMachinery setupCoroMachinery(FuncOp func) { 325 assert(func.getBody().empty() && "Function must have empty body"); 326 327 MLIRContext *ctx = func.getContext(); 328 329 auto token = LLVM::LLVMTokenType::get(ctx); 330 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 331 auto i32 = LLVM::LLVMType::getInt32Ty(ctx); 332 auto i64 = LLVM::LLVMType::getInt64Ty(ctx); 333 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 334 335 Block *entryBlock = func.addEntryBlock(); 336 Location loc = func.getBody().getLoc(); 337 338 OpBuilder builder = OpBuilder::atBlockBegin(entryBlock); 339 340 // ------------------------------------------------------------------------ // 341 // Allocate async tokens/values that we will return from a ramp function. 342 // ------------------------------------------------------------------------ // 343 auto createToken = 344 builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx)); 345 346 // ------------------------------------------------------------------------ // 347 // Initialize coroutine: allocate frame, get coroutine handle. 348 // ------------------------------------------------------------------------ // 349 350 // Constants for initializing coroutine frame. 
351 auto constZero = 352 builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0)); 353 auto constFalse = 354 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 355 auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr); 356 357 // Get coroutine id: @llvm.coro.id 358 auto coroId = builder.create<LLVM::CallOp>( 359 loc, token, builder.getSymbolRefAttr(kCoroId), 360 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 361 362 // Get coroutine frame size: @llvm.coro.size.i64 363 auto coroSize = builder.create<LLVM::CallOp>( 364 loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 365 366 // Allocate memory for coroutine frame. 367 auto coroAlloc = builder.create<LLVM::CallOp>( 368 loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), 369 ValueRange(coroSize.getResult(0))); 370 371 // Begin a coroutine: @llvm.coro.begin 372 auto coroHdl = builder.create<LLVM::CallOp>( 373 loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 374 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 375 376 Block *cleanupBlock = func.addBlock(); 377 Block *suspendBlock = func.addBlock(); 378 379 // ------------------------------------------------------------------------ // 380 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 381 // ------------------------------------------------------------------------ // 382 builder.setInsertionPointToStart(cleanupBlock); 383 384 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 385 auto coroMem = builder.create<LLVM::CallOp>( 386 loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), 387 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 388 389 // Free the memory. 390 builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree), 391 ValueRange(coroMem.getResult(0))); 392 // Branch into the suspend block. 
393 builder.create<BranchOp>(loc, suspendBlock); 394 395 // ------------------------------------------------------------------------ // 396 // Coroutine suspend block: mark the end of a coroutine and return allocated 397 // async token. 398 // ------------------------------------------------------------------------ // 399 builder.setInsertionPointToStart(suspendBlock); 400 401 // Mark the end of a coroutine: @llvm.coro.end. 402 builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd), 403 ValueRange({coroHdl.getResult(0), constFalse})); 404 405 // Return created `async.token` from the suspend block. This will be the 406 // return value of a coroutine ramp function. 407 builder.create<ReturnOp>(loc, createToken.getResult(0)); 408 409 // Branch from the entry block to the cleanup block to create a valid CFG. 410 builder.setInsertionPointToEnd(entryBlock); 411 412 builder.create<BranchOp>(loc, cleanupBlock); 413 414 // `async.await` op lowering will create resume blocks for async 415 // continuations, and will conditionally branch to cleanup or suspend blocks. 416 417 return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock, 418 suspendBlock}; 419 } 420 421 // Adds a suspension point before the `op`, and moves `op` and all operations 422 // after it into the resume block. Returns a pointer to the resume block. 423 // 424 // `coroState` must be a value returned from the call to @llvm.coro.save(...) 425 // intrinsic (saved coroutine state). 426 // 427 // Before: 428 // 429 // ^bb0: 430 // "opBefore"(...) 431 // "op"(...) 432 // ^cleanup: ... 433 // ^suspend: ... 434 // 435 // After: 436 // 437 // ^bb0: 438 // "opBefore"(...) 439 // %suspend = llmv.call @llvm.coro.suspend(...) 440 // switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 441 // ^resume: 442 // "op"(...) 443 // ^cleanup: ... 444 // ^suspend: ... 
445 // 446 static Block *addSuspensionPoint(CoroMachinery coro, Value coroState, 447 Operation *op) { 448 MLIRContext *ctx = op->getContext(); 449 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 450 auto i8 = LLVM::LLVMType::getInt8Ty(ctx); 451 452 Location loc = op->getLoc(); 453 Block *splitBlock = op->getBlock(); 454 455 // Split the block before `op`, newly added block is the resume block. 456 Block *resume = splitBlock->splitBlock(op); 457 458 // Add a coroutine suspension in place of original `op` in the split block. 459 OpBuilder builder = OpBuilder::atBlockEnd(splitBlock); 460 461 auto constFalse = 462 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 463 464 // Suspend a coroutine: @llvm.coro.suspend 465 auto coroSuspend = builder.create<LLVM::CallOp>( 466 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 467 ValueRange({coroState, constFalse})); 468 469 // After a suspension point decide if we should branch into resume, cleanup 470 // or suspend block of the coroutine (see @llvm.coro.suspend return code 471 // documentation). 472 auto constZero = 473 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 474 auto constNegOne = 475 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 476 477 Block *resumeOrCleanup = builder.createBlock(resume); 478 479 // Suspend the coroutine ...? 480 builder.setInsertionPointToEnd(splitBlock); 481 auto isNegOne = builder.create<LLVM::ICmpOp>( 482 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 483 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 484 /*falseDest=*/resumeOrCleanup); 485 486 // ... or resume or cleanup the coroutine? 
487 builder.setInsertionPointToStart(resumeOrCleanup); 488 auto isZero = builder.create<LLVM::ICmpOp>( 489 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 490 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 491 /*falseDest=*/coro.cleanup); 492 493 return resume; 494 } 495 496 // Outline the body region attached to the `async.execute` op into a standalone 497 // function. 498 static std::pair<FuncOp, CoroMachinery> 499 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 500 ModuleOp module = execute.getParentOfType<ModuleOp>(); 501 502 MLIRContext *ctx = module.getContext(); 503 Location loc = execute.getLoc(); 504 505 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 506 507 // Collect all outlined function inputs. 508 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 509 execute.dependencies().end()); 510 getUsedValuesDefinedAbove(execute.body(), functionInputs); 511 512 // Collect types for the outlined function inputs and outputs. 513 auto typesRange = llvm::map_range( 514 functionInputs, [](Value value) { return value.getType(); }); 515 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 516 auto outputTypes = execute.getResultTypes(); 517 518 auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); 519 auto funcAttrs = ArrayRef<NamedAttribute>(); 520 521 // TODO: Derive outlined function name from the parent FuncOp (support 522 // multiple nested async.execute operations). 523 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 524 symbolTable.insert(func, moduleBuilder.getInsertionPoint()); 525 526 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 527 528 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 529 // blocks, adding llvm.coro instrinsics and setting up control flow. 
530 CoroMachinery coro = setupCoroMachinery(func); 531 532 // Suspend async function at the end of an entry block, and resume it using 533 // Async execute API (execution will be resumed in a thread managed by the 534 // async runtime). 535 Block *entryBlock = &func.getBlocks().front(); 536 OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock); 537 538 // A pointer to coroutine resume intrinsic wrapper. 539 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 540 auto resumePtr = builder.create<LLVM::AddressOfOp>( 541 loc, resumeFnTy.getPointerTo(), kResume); 542 543 // Save the coroutine state: @llvm.coro.save 544 auto coroSave = builder.create<LLVM::CallOp>( 545 loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 546 ValueRange({coro.coroHandle})); 547 548 // Call async runtime API to execute a coroutine in the managed thread. 549 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 550 builder.create<CallOp>(loc, Type(), kExecute, executeArgs); 551 552 // Split the entry block before the terminator. 553 Block *resume = addSuspensionPoint(coro, coroSave.getResult(0), 554 entryBlock->getTerminator()); 555 556 // Await on all dependencies before starting to execute the body region. 557 builder.setInsertionPointToStart(resume); 558 for (size_t i = 0; i < execute.dependencies().size(); ++i) 559 builder.create<AwaitOp>(loc, func.getArgument(i)); 560 561 // Map from function inputs defined above the execute op to the function 562 // arguments. 563 BlockAndValueMapping valueMapping; 564 valueMapping.map(functionInputs, func.getArguments()); 565 566 // Clone all operations from the execute operation body into the outlined 567 // function body, and replace all `async.yield` operations with a call 568 // to async runtime to emplace the result token. 
569 for (Operation &op : execute.body().getOps()) { 570 if (isa<async::YieldOp>(op)) { 571 builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken); 572 continue; 573 } 574 builder.clone(op, valueMapping); 575 } 576 577 // Replace the original `async.execute` with a call to outlined function. 578 OpBuilder callBuilder(execute); 579 auto callOutlinedFunc = 580 callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(), 581 functionInputs.getArrayRef()); 582 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 583 execute.erase(); 584 585 return {func, coro}; 586 } 587 588 //===----------------------------------------------------------------------===// 589 // Convert Async dialect types to LLVM types. 590 //===----------------------------------------------------------------------===// 591 592 namespace { 593 class AsyncRuntimeTypeConverter : public TypeConverter { 594 public: 595 AsyncRuntimeTypeConverter() { addConversion(convertType); } 596 597 static Type convertType(Type type) { 598 MLIRContext *ctx = type.getContext(); 599 // Convert async tokens and groups to opaque pointers. 600 if (type.isa<TokenType, GroupType>()) 601 return LLVM::LLVMType::getInt8PtrTy(ctx); 602 return type; 603 } 604 }; 605 } // namespace 606 607 //===----------------------------------------------------------------------===// 608 // Convert types for all call operations to lowered async types. 
609 //===----------------------------------------------------------------------===// 610 611 namespace { 612 class CallOpOpConversion : public ConversionPattern { 613 public: 614 explicit CallOpOpConversion(MLIRContext *ctx) 615 : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} 616 617 LogicalResult 618 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 619 ConversionPatternRewriter &rewriter) const override { 620 AsyncRuntimeTypeConverter converter; 621 622 SmallVector<Type, 5> resultTypes; 623 converter.convertTypes(op->getResultTypes(), resultTypes); 624 625 CallOp call = cast<CallOp>(op); 626 rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(), 627 call.getOperands()); 628 629 return success(); 630 } 631 }; 632 } // namespace 633 634 //===----------------------------------------------------------------------===// 635 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 636 //===----------------------------------------------------------------------===// 637 638 namespace { 639 class CreateGroupOpLowering : public ConversionPattern { 640 public: 641 explicit CreateGroupOpLowering(MLIRContext *ctx) 642 : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {} 643 644 LogicalResult 645 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 646 ConversionPatternRewriter &rewriter) const override { 647 auto retTy = GroupType::get(op->getContext()); 648 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy); 649 return success(); 650 } 651 }; 652 } // namespace 653 654 //===----------------------------------------------------------------------===// 655 // async.add_to_group op lowering to runtime function call. 
656 //===----------------------------------------------------------------------===// 657 658 namespace { 659 class AddToGroupOpLowering : public ConversionPattern { 660 public: 661 explicit AddToGroupOpLowering(MLIRContext *ctx) 662 : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {} 663 664 LogicalResult 665 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 666 ConversionPatternRewriter &rewriter) const override { 667 // Currently we can only add tokens to the group. 668 auto addToGroup = cast<AddToGroupOp>(op); 669 if (!addToGroup.operand().getType().isa<TokenType>()) 670 return failure(); 671 672 auto i64 = IntegerType::get(64, op->getContext()); 673 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); 674 return success(); 675 } 676 }; 677 } // namespace 678 679 //===----------------------------------------------------------------------===// 680 // async.await and async.await_all op lowerings to the corresponding async 681 // runtime function calls. 682 //===----------------------------------------------------------------------===// 683 684 namespace { 685 686 template <typename AwaitType, typename AwaitableType> 687 class AwaitOpLoweringBase : public ConversionPattern { 688 protected: 689 explicit AwaitOpLoweringBase( 690 MLIRContext *ctx, 691 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions, 692 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) 693 : ConversionPattern(AwaitType::getOperationName(), 1, ctx), 694 outlinedFunctions(outlinedFunctions), 695 blockingAwaitFuncName(blockingAwaitFuncName), 696 coroAwaitFuncName(coroAwaitFuncName) {} 697 698 public: 699 LogicalResult 700 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 701 ConversionPatternRewriter &rewriter) const override { 702 // We can only await on one the `AwaitableType` (for `await` it can be 703 // only a `token`, for `await_all` it is a `group`). 
704 auto await = cast<AwaitType>(op); 705 if (!await.operand().getType().template isa<AwaitableType>()) 706 return failure(); 707 708 // Check if await operation is inside the outlined coroutine function. 709 auto func = await.template getParentOfType<FuncOp>(); 710 auto outlined = outlinedFunctions.find(func); 711 const bool isInCoroutine = outlined != outlinedFunctions.end(); 712 713 Location loc = op->getLoc(); 714 715 // Inside regular function we convert await operation to the blocking 716 // async API await function call. 717 if (!isInCoroutine) 718 rewriter.create<CallOp>(loc, Type(), blockingAwaitFuncName, 719 ValueRange(op->getOperand(0))); 720 721 // Inside the coroutine we convert await operation into coroutine suspension 722 // point, and resume execution asynchronously. 723 if (isInCoroutine) { 724 const CoroMachinery &coro = outlined->getSecond(); 725 726 OpBuilder builder(op); 727 MLIRContext *ctx = op->getContext(); 728 729 // A pointer to coroutine resume intrinsic wrapper. 730 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 731 auto resumePtr = builder.create<LLVM::AddressOfOp>( 732 loc, resumeFnTy.getPointerTo(), kResume); 733 734 // Save the coroutine state: @llvm.coro.save 735 auto coroSave = builder.create<LLVM::CallOp>( 736 loc, LLVM::LLVMTokenType::get(ctx), 737 builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); 738 739 // Call async runtime API to resume a coroutine in the managed thread when 740 // the async await argument becomes ready. 741 SmallVector<Value, 3> awaitAndExecuteArgs = { 742 await.getOperand(), coro.coroHandle, resumePtr.res()}; 743 builder.create<CallOp>(loc, Type(), coroAwaitFuncName, 744 awaitAndExecuteArgs); 745 746 // Split the entry block before the await operation. 747 addSuspensionPoint(coro, coroSave.getResult(0), op); 748 } 749 750 // Original operation was replaced by function call or suspension point. 
751 rewriter.eraseOp(op); 752 753 return success(); 754 } 755 756 private: 757 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 758 StringRef blockingAwaitFuncName; 759 StringRef coroAwaitFuncName; 760 }; 761 762 // Lowering for `async.await` operation (only token operands are supported). 763 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 764 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 765 766 public: 767 explicit AwaitOpLowering( 768 MLIRContext *ctx, 769 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 770 : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {} 771 }; 772 773 // Lowering for `async.await_all` operation. 774 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 775 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 776 777 public: 778 explicit AwaitAllOpLowering( 779 MLIRContext *ctx, 780 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 781 : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {} 782 }; 783 784 } // namespace 785 786 //===----------------------------------------------------------------------===// 787 788 namespace { 789 struct ConvertAsyncToLLVMPass 790 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 791 void runOnOperation() override; 792 }; 793 794 void ConvertAsyncToLLVMPass::runOnOperation() { 795 ModuleOp module = getOperation(); 796 SymbolTable symbolTable(module); 797 798 // Outline all `async.execute` body regions into async functions (coroutines). 799 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 800 801 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 802 // We currently do not support execute operations that have async value 803 // operands or produce async results. 
804 if (!execute.operands().empty() || !execute.results().empty()) { 805 execute.emitOpError("can't outline async.execute op with async value " 806 "operands or returned async results"); 807 return WalkResult::interrupt(); 808 } 809 810 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 811 812 return WalkResult::advance(); 813 }); 814 815 // Failed to outline all async execute operations. 816 if (outlineResult.wasInterrupted()) { 817 signalPassFailure(); 818 return; 819 } 820 821 LLVM_DEBUG({ 822 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 823 << " async functions\n"; 824 }); 825 826 // Add declarations for all functions required by the coroutines lowering. 827 addResumeFunction(module); 828 addAsyncRuntimeApiDeclarations(module); 829 addCoroutineIntrinsicsDeclarations(module); 830 addCRuntimeDeclarations(module); 831 832 MLIRContext *ctx = &getContext(); 833 834 // Convert async dialect types and operations to LLVM dialect. 835 AsyncRuntimeTypeConverter converter; 836 OwningRewritePatternList patterns; 837 838 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 839 patterns.insert<CallOpOpConversion>(ctx); 840 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 841 patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions); 842 843 ConversionTarget target(*ctx); 844 target.addLegalDialect<LLVM::LLVMDialect>(); 845 target.addIllegalDialect<AsyncDialect>(); 846 target.addDynamicallyLegalOp<FuncOp>( 847 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 848 target.addDynamicallyLegalOp<CallOp>( 849 [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); 850 851 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 852 signalPassFailure(); 853 } 854 } // namespace 855 856 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 857 return std::make_unique<ConvertAsyncToLLVMPass>(); 858 } 859