//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lowers the async dialect to the LLVM dialect: `async.execute` regions are
// outlined into coroutine functions built on the LLVM coroutine intrinsics, and
// the remaining async operations become calls into the Async Runtime C API.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"

// Tag for the LLVM_DEBUG output emitted by this file.
#define DEBUG_TYPE "convert-async-to-llvm"

using namespace mlir;
using namespace mlir::async;

// Symbol name given to every function outlined from an `async.execute` op
// region. All outlined functions are created with this same name and inserted
// through a SymbolTable in `outlineExecuteOp`.
// NOTE(review): relies on SymbolTable::insert making the name unique on
// collision when there are several `async.execute` ops — confirm.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

//===----------------------------------------------------------------------===//
// Async Runtime C API declarations.
34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 37 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 38 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 39 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 40 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 41 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 42 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 43 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 44 static constexpr const char *kAddTokenToGroup = 45 "mlirAsyncRuntimeAddTokenToGroup"; 46 static constexpr const char *kAwaitAndExecute = 47 "mlirAsyncRuntimeAwaitTokenAndExecute"; 48 static constexpr const char *kAwaitAllAndExecute = 49 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 50 51 namespace { 52 // Async Runtime API function types. 
53 struct AsyncAPI { 54 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 55 auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); 56 auto count = IntegerType::get(32, ctx); 57 return FunctionType::get({ref, count}, {}, ctx); 58 } 59 60 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 61 return FunctionType::get({}, {TokenType::get(ctx)}, ctx); 62 } 63 64 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 65 return FunctionType::get({}, {GroupType::get(ctx)}, ctx); 66 } 67 68 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 69 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 70 } 71 72 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 73 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 74 } 75 76 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 77 return FunctionType::get({GroupType::get(ctx)}, {}, ctx); 78 } 79 80 static FunctionType executeFunctionType(MLIRContext *ctx) { 81 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 82 auto resume = resumeFunctionType(ctx).getPointerTo(); 83 return FunctionType::get({hdl, resume}, {}, ctx); 84 } 85 86 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 87 auto i64 = IntegerType::get(64, ctx); 88 return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64}, 89 ctx); 90 } 91 92 static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { 93 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 94 auto resume = resumeFunctionType(ctx).getPointerTo(); 95 return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); 96 } 97 98 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 99 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 100 auto resume = resumeFunctionType(ctx).getPointerTo(); 101 return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx); 102 } 103 104 // Auxiliary coroutine resume intrinsic wrapper. 
105 static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { 106 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 107 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 108 return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); 109 } 110 }; 111 } // namespace 112 113 // Adds Async Runtime C API declarations to the module. 114 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 115 auto builder = OpBuilder::atBlockTerminator(module.getBody()); 116 117 auto addFuncDecl = [&](StringRef name, FunctionType type) { 118 if (module.lookupSymbol(name)) 119 return; 120 builder.create<FuncOp>(module.getLoc(), name, type).setPrivate(); 121 }; 122 123 MLIRContext *ctx = module.getContext(); 124 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 125 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 126 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 127 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 128 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 129 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 130 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 131 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 132 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 133 addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx)); 134 addFuncDecl(kAwaitAllAndExecute, 135 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // LLVM coroutines intrinsics declarations. 
140 //===----------------------------------------------------------------------===// 141 142 static constexpr const char *kCoroId = "llvm.coro.id"; 143 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 144 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 145 static constexpr const char *kCoroSave = "llvm.coro.save"; 146 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 147 static constexpr const char *kCoroEnd = "llvm.coro.end"; 148 static constexpr const char *kCoroFree = "llvm.coro.free"; 149 static constexpr const char *kCoroResume = "llvm.coro.resume"; 150 151 /// Adds an LLVM function declaration to a module. 152 static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name, 153 LLVM::LLVMType ret, 154 ArrayRef<LLVM::LLVMType> params) { 155 if (module.lookupSymbol(name)) 156 return; 157 LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false); 158 builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type); 159 } 160 161 /// Adds coroutine intrinsics declarations to the module. 
162 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 163 using namespace mlir::LLVM; 164 165 MLIRContext *ctx = module.getContext(); 166 OpBuilder builder(module.getBody()->getTerminator()); 167 168 auto token = LLVMTokenType::get(ctx); 169 auto voidTy = LLVMType::getVoidTy(ctx); 170 171 auto i8 = LLVMType::getInt8Ty(ctx); 172 auto i1 = LLVMType::getInt1Ty(ctx); 173 auto i32 = LLVMType::getInt32Ty(ctx); 174 auto i64 = LLVMType::getInt64Ty(ctx); 175 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 176 177 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); 178 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); 179 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); 180 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); 181 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); 182 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); 183 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); 184 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); 185 } 186 187 //===----------------------------------------------------------------------===// 188 // Add malloc/free declarations to the module. 189 //===----------------------------------------------------------------------===// 190 191 static constexpr const char *kMalloc = "malloc"; 192 static constexpr const char *kFree = "free"; 193 194 /// Adds malloc/free declarations to the module. 
195 static void addCRuntimeDeclarations(ModuleOp module) { 196 using namespace mlir::LLVM; 197 198 MLIRContext *ctx = module.getContext(); 199 OpBuilder builder(module.getBody()->getTerminator()); 200 201 auto voidTy = LLVMType::getVoidTy(ctx); 202 auto i64 = LLVMType::getInt64Ty(ctx); 203 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 204 205 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 206 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 207 } 208 209 //===----------------------------------------------------------------------===// 210 // Coroutine resume function wrapper. 211 //===----------------------------------------------------------------------===// 212 213 static constexpr const char *kResume = "__resume"; 214 215 // A function that takes a coroutine handle and calls a `llvm.coro.resume` 216 // intrinsics. We need this function to be able to pass it to the async 217 // runtime execute API. 218 static void addResumeFunction(ModuleOp module) { 219 MLIRContext *ctx = module.getContext(); 220 221 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 222 Location loc = module.getLoc(); 223 224 if (module.lookupSymbol(kResume)) 225 return; 226 227 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 228 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 229 230 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 231 loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false)); 232 resumeOp.setPrivate(); 233 234 auto *block = resumeOp.addEntryBlock(); 235 OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); 236 237 blockBuilder.create<LLVM::CallOp>(loc, TypeRange(), 238 blockBuilder.getSymbolRefAttr(kCoroResume), 239 resumeOp.getArgument(0)); 240 241 blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange()); 242 } 243 244 //===----------------------------------------------------------------------===// 245 // async.execute op outlining to the coroutine functions. 
246 //===----------------------------------------------------------------------===// 247 248 // Function targeted for coroutine transformation has two additional blocks at 249 // the end: coroutine cleanup and coroutine suspension. 250 // 251 // async.await op lowering additionaly creates a resume block for each 252 // operation to enable non-blocking waiting via coroutine suspension. 253 namespace { 254 struct CoroMachinery { 255 Value asyncToken; 256 Value coroHandle; 257 Block *cleanup; 258 Block *suspend; 259 }; 260 } // namespace 261 262 // Builds an coroutine template compatible with LLVM coroutines lowering. 263 // 264 // - `entry` block sets up the coroutine. 265 // - `cleanup` block cleans up the coroutine state. 266 // - `suspend block after the @llvm.coro.end() defines what value will be 267 // returned to the initial caller of a coroutine. Everything before the 268 // @llvm.coro.end() will be executed at every suspension point. 269 // 270 // Coroutine structure (only the important bits): 271 // 272 // func @async_execute_fn(<function-arguments>) -> !async.token { 273 // ^entryBlock(<function-arguments>): 274 // %token = <async token> : !async.token // create async runtime token 275 // %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle 276 // br ^cleanup 277 // 278 // ^cleanup: 279 // llvm.call @llvm.coro.free(...) // delete coroutine state 280 // br ^suspend 281 // 282 // ^suspend: 283 // llvm.call @llvm.coro.end(...) // marks the end of a coroutine 284 // return %token : !async.token 285 // } 286 // 287 // The actual code for the async.execute operation body region will be inserted 288 // before the entry block terminator. 
289 // 290 // 291 static CoroMachinery setupCoroMachinery(FuncOp func) { 292 assert(func.getBody().empty() && "Function must have empty body"); 293 294 MLIRContext *ctx = func.getContext(); 295 296 auto token = LLVM::LLVMTokenType::get(ctx); 297 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 298 auto i32 = LLVM::LLVMType::getInt32Ty(ctx); 299 auto i64 = LLVM::LLVMType::getInt64Ty(ctx); 300 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 301 302 Block *entryBlock = func.addEntryBlock(); 303 Location loc = func.getBody().getLoc(); 304 305 OpBuilder builder = OpBuilder::atBlockBegin(entryBlock); 306 307 // ------------------------------------------------------------------------ // 308 // Allocate async tokens/values that we will return from a ramp function. 309 // ------------------------------------------------------------------------ // 310 auto createToken = 311 builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx)); 312 313 // ------------------------------------------------------------------------ // 314 // Initialize coroutine: allocate frame, get coroutine handle. 315 // ------------------------------------------------------------------------ // 316 317 // Constants for initializing coroutine frame. 318 auto constZero = 319 builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0)); 320 auto constFalse = 321 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 322 auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr); 323 324 // Get coroutine id: @llvm.coro.id 325 auto coroId = builder.create<LLVM::CallOp>( 326 loc, token, builder.getSymbolRefAttr(kCoroId), 327 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 328 329 // Get coroutine frame size: @llvm.coro.size.i64 330 auto coroSize = builder.create<LLVM::CallOp>( 331 loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 332 333 // Allocate memory for coroutine frame. 
334 auto coroAlloc = builder.create<LLVM::CallOp>( 335 loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), 336 ValueRange(coroSize.getResult(0))); 337 338 // Begin a coroutine: @llvm.coro.begin 339 auto coroHdl = builder.create<LLVM::CallOp>( 340 loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 341 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 342 343 Block *cleanupBlock = func.addBlock(); 344 Block *suspendBlock = func.addBlock(); 345 346 // ------------------------------------------------------------------------ // 347 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 348 // ------------------------------------------------------------------------ // 349 builder.setInsertionPointToStart(cleanupBlock); 350 351 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 352 auto coroMem = builder.create<LLVM::CallOp>( 353 loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), 354 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 355 356 // Free the memory. 357 builder.create<LLVM::CallOp>(loc, TypeRange(), 358 builder.getSymbolRefAttr(kFree), 359 ValueRange(coroMem.getResult(0))); 360 // Branch into the suspend block. 361 builder.create<BranchOp>(loc, suspendBlock); 362 363 // ------------------------------------------------------------------------ // 364 // Coroutine suspend block: mark the end of a coroutine and return allocated 365 // async token. 366 // ------------------------------------------------------------------------ // 367 builder.setInsertionPointToStart(suspendBlock); 368 369 // Mark the end of a coroutine: @llvm.coro.end. 370 builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd), 371 ValueRange({coroHdl.getResult(0), constFalse})); 372 373 // Return created `async.token` from the suspend block. This will be the 374 // return value of a coroutine ramp function. 
375 builder.create<ReturnOp>(loc, createToken.getResult(0)); 376 377 // Branch from the entry block to the cleanup block to create a valid CFG. 378 builder.setInsertionPointToEnd(entryBlock); 379 380 builder.create<BranchOp>(loc, cleanupBlock); 381 382 // `async.await` op lowering will create resume blocks for async 383 // continuations, and will conditionally branch to cleanup or suspend blocks. 384 385 return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock, 386 suspendBlock}; 387 } 388 389 // Adds a suspension point before the `op`, and moves `op` and all operations 390 // after it into the resume block. Returns a pointer to the resume block. 391 // 392 // `coroState` must be a value returned from the call to @llvm.coro.save(...) 393 // intrinsic (saved coroutine state). 394 // 395 // Before: 396 // 397 // ^bb0: 398 // "opBefore"(...) 399 // "op"(...) 400 // ^cleanup: ... 401 // ^suspend: ... 402 // 403 // After: 404 // 405 // ^bb0: 406 // "opBefore"(...) 407 // %suspend = llmv.call @llvm.coro.suspend(...) 408 // switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 409 // ^resume: 410 // "op"(...) 411 // ^cleanup: ... 412 // ^suspend: ... 413 // 414 static Block *addSuspensionPoint(CoroMachinery coro, Value coroState, 415 Operation *op) { 416 MLIRContext *ctx = op->getContext(); 417 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 418 auto i8 = LLVM::LLVMType::getInt8Ty(ctx); 419 420 Location loc = op->getLoc(); 421 Block *splitBlock = op->getBlock(); 422 423 // Split the block before `op`, newly added block is the resume block. 424 Block *resume = splitBlock->splitBlock(op); 425 426 // Add a coroutine suspension in place of original `op` in the split block. 
427 OpBuilder builder = OpBuilder::atBlockEnd(splitBlock); 428 429 auto constFalse = 430 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 431 432 // Suspend a coroutine: @llvm.coro.suspend 433 auto coroSuspend = builder.create<LLVM::CallOp>( 434 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 435 ValueRange({coroState, constFalse})); 436 437 // After a suspension point decide if we should branch into resume, cleanup 438 // or suspend block of the coroutine (see @llvm.coro.suspend return code 439 // documentation). 440 auto constZero = 441 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 442 auto constNegOne = 443 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 444 445 Block *resumeOrCleanup = builder.createBlock(resume); 446 447 // Suspend the coroutine ...? 448 builder.setInsertionPointToEnd(splitBlock); 449 auto isNegOne = builder.create<LLVM::ICmpOp>( 450 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 451 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 452 /*falseDest=*/resumeOrCleanup); 453 454 // ... or resume or cleanup the coroutine? 455 builder.setInsertionPointToStart(resumeOrCleanup); 456 auto isZero = builder.create<LLVM::ICmpOp>( 457 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 458 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 459 /*falseDest=*/coro.cleanup); 460 461 return resume; 462 } 463 464 // Outline the body region attached to the `async.execute` op into a standalone 465 // function. 466 static std::pair<FuncOp, CoroMachinery> 467 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 468 ModuleOp module = execute.getParentOfType<ModuleOp>(); 469 470 MLIRContext *ctx = module.getContext(); 471 Location loc = execute.getLoc(); 472 473 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 474 475 // Collect all outlined function inputs. 
476 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 477 execute.dependencies().end()); 478 getUsedValuesDefinedAbove(execute.body(), functionInputs); 479 480 // Collect types for the outlined function inputs and outputs. 481 auto typesRange = llvm::map_range( 482 functionInputs, [](Value value) { return value.getType(); }); 483 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 484 auto outputTypes = execute.getResultTypes(); 485 486 auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); 487 auto funcAttrs = ArrayRef<NamedAttribute>(); 488 489 // TODO: Derive outlined function name from the parent FuncOp (support 490 // multiple nested async.execute operations). 491 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 492 symbolTable.insert(func, moduleBuilder.getInsertionPoint()); 493 494 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 495 496 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 497 // blocks, adding llvm.coro instrinsics and setting up control flow. 498 CoroMachinery coro = setupCoroMachinery(func); 499 500 // Suspend async function at the end of an entry block, and resume it using 501 // Async execute API (execution will be resumed in a thread managed by the 502 // async runtime). 503 Block *entryBlock = &func.getBlocks().front(); 504 OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock); 505 506 // A pointer to coroutine resume intrinsic wrapper. 507 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 508 auto resumePtr = builder.create<LLVM::AddressOfOp>( 509 loc, resumeFnTy.getPointerTo(), kResume); 510 511 // Save the coroutine state: @llvm.coro.save 512 auto coroSave = builder.create<LLVM::CallOp>( 513 loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 514 ValueRange({coro.coroHandle})); 515 516 // Call async runtime API to execute a coroutine in the managed thread. 
517 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 518 builder.create<CallOp>(loc, TypeRange(), kExecute, executeArgs); 519 520 // Split the entry block before the terminator. 521 Block *resume = addSuspensionPoint(coro, coroSave.getResult(0), 522 entryBlock->getTerminator()); 523 524 // Await on all dependencies before starting to execute the body region. 525 builder.setInsertionPointToStart(resume); 526 for (size_t i = 0; i < execute.dependencies().size(); ++i) 527 builder.create<AwaitOp>(loc, func.getArgument(i)); 528 529 // Map from function inputs defined above the execute op to the function 530 // arguments. 531 BlockAndValueMapping valueMapping; 532 valueMapping.map(functionInputs, func.getArguments()); 533 534 // Clone all operations from the execute operation body into the outlined 535 // function body, and replace all `async.yield` operations with a call 536 // to async runtime to emplace the result token. 537 for (Operation &op : execute.body().getOps()) { 538 if (isa<async::YieldOp>(op)) { 539 builder.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken); 540 continue; 541 } 542 builder.clone(op, valueMapping); 543 } 544 545 // Replace the original `async.execute` with a call to outlined function. 546 OpBuilder callBuilder(execute); 547 auto callOutlinedFunc = 548 callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(), 549 functionInputs.getArrayRef()); 550 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 551 execute.erase(); 552 553 return {func, coro}; 554 } 555 556 //===----------------------------------------------------------------------===// 557 // Convert Async dialect types to LLVM types. 
558 //===----------------------------------------------------------------------===// 559 560 namespace { 561 class AsyncRuntimeTypeConverter : public TypeConverter { 562 public: 563 AsyncRuntimeTypeConverter() { addConversion(convertType); } 564 565 static Type convertType(Type type) { 566 MLIRContext *ctx = type.getContext(); 567 // Convert async tokens and groups to opaque pointers. 568 if (type.isa<TokenType, GroupType>()) 569 return LLVM::LLVMType::getInt8PtrTy(ctx); 570 return type; 571 } 572 }; 573 } // namespace 574 575 //===----------------------------------------------------------------------===// 576 // Convert types for all call operations to lowered async types. 577 //===----------------------------------------------------------------------===// 578 579 namespace { 580 class CallOpOpConversion : public ConversionPattern { 581 public: 582 explicit CallOpOpConversion(MLIRContext *ctx) 583 : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} 584 585 LogicalResult 586 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 587 ConversionPatternRewriter &rewriter) const override { 588 AsyncRuntimeTypeConverter converter; 589 590 SmallVector<Type, 5> resultTypes; 591 converter.convertTypes(op->getResultTypes(), resultTypes); 592 593 CallOp call = cast<CallOp>(op); 594 rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(), 595 operands); 596 597 return success(); 598 } 599 }; 600 } // namespace 601 602 //===----------------------------------------------------------------------===// 603 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` 604 // to the corresponding API calls). 
605 //===----------------------------------------------------------------------===// 606 607 namespace { 608 609 template <typename RefCountingOp> 610 class RefCountingOpLowering : public ConversionPattern { 611 public: 612 explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName) 613 : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx), 614 apiFunctionName(apiFunctionName) {} 615 616 LogicalResult 617 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 618 ConversionPatternRewriter &rewriter) const override { 619 RefCountingOp refCountingOp = cast<RefCountingOp>(op); 620 621 auto count = rewriter.create<ConstantOp>( 622 op->getLoc(), rewriter.getI32Type(), 623 rewriter.getI32IntegerAttr(refCountingOp.count())); 624 625 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 626 ValueRange({operands[0], count})); 627 628 return success(); 629 } 630 631 private: 632 StringRef apiFunctionName; 633 }; 634 635 // async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. 636 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> { 637 public: 638 explicit AddRefOpLowering(MLIRContext *ctx) 639 : RefCountingOpLowering(ctx, kAddRef) {} 640 }; 641 642 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 643 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> { 644 public: 645 explicit DropRefOpLowering(MLIRContext *ctx) 646 : RefCountingOpLowering(ctx, kDropRef) {} 647 }; 648 649 } // namespace 650 651 //===----------------------------------------------------------------------===// 652 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 
653 //===----------------------------------------------------------------------===// 654 655 namespace { 656 class CreateGroupOpLowering : public ConversionPattern { 657 public: 658 explicit CreateGroupOpLowering(MLIRContext *ctx) 659 : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {} 660 661 LogicalResult 662 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 663 ConversionPatternRewriter &rewriter) const override { 664 auto retTy = GroupType::get(op->getContext()); 665 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy); 666 return success(); 667 } 668 }; 669 } // namespace 670 671 //===----------------------------------------------------------------------===// 672 // async.add_to_group op lowering to runtime function call. 673 //===----------------------------------------------------------------------===// 674 675 namespace { 676 class AddToGroupOpLowering : public ConversionPattern { 677 public: 678 explicit AddToGroupOpLowering(MLIRContext *ctx) 679 : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {} 680 681 LogicalResult 682 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 683 ConversionPatternRewriter &rewriter) const override { 684 // Currently we can only add tokens to the group. 685 auto addToGroup = cast<AddToGroupOp>(op); 686 if (!addToGroup.operand().getType().isa<TokenType>()) 687 return failure(); 688 689 auto i64 = IntegerType::get(64, op->getContext()); 690 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); 691 return success(); 692 } 693 }; 694 } // namespace 695 696 //===----------------------------------------------------------------------===// 697 // async.await and async.await_all op lowerings to the corresponding async 698 // runtime function calls. 
699 //===----------------------------------------------------------------------===// 700 701 namespace { 702 703 template <typename AwaitType, typename AwaitableType> 704 class AwaitOpLoweringBase : public ConversionPattern { 705 protected: 706 explicit AwaitOpLoweringBase( 707 MLIRContext *ctx, 708 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions, 709 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) 710 : ConversionPattern(AwaitType::getOperationName(), 1, ctx), 711 outlinedFunctions(outlinedFunctions), 712 blockingAwaitFuncName(blockingAwaitFuncName), 713 coroAwaitFuncName(coroAwaitFuncName) {} 714 715 public: 716 LogicalResult 717 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 718 ConversionPatternRewriter &rewriter) const override { 719 // We can only await on one the `AwaitableType` (for `await` it can be 720 // only a `token`, for `await_all` it is a `group`). 721 auto await = cast<AwaitType>(op); 722 if (!await.operand().getType().template isa<AwaitableType>()) 723 return failure(); 724 725 // Check if await operation is inside the outlined coroutine function. 726 auto func = await.template getParentOfType<FuncOp>(); 727 auto outlined = outlinedFunctions.find(func); 728 const bool isInCoroutine = outlined != outlinedFunctions.end(); 729 730 Location loc = op->getLoc(); 731 732 // Inside regular function we convert await operation to the blocking 733 // async API await function call. 734 if (!isInCoroutine) 735 rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName, 736 ValueRange(operands[0])); 737 738 // Inside the coroutine we convert await operation into coroutine suspension 739 // point, and resume execution asynchronously. 740 if (isInCoroutine) { 741 const CoroMachinery &coro = outlined->getSecond(); 742 743 OpBuilder builder(op); 744 MLIRContext *ctx = op->getContext(); 745 746 // A pointer to coroutine resume intrinsic wrapper. 
747 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 748 auto resumePtr = builder.create<LLVM::AddressOfOp>( 749 loc, resumeFnTy.getPointerTo(), kResume); 750 751 // Save the coroutine state: @llvm.coro.save 752 auto coroSave = builder.create<LLVM::CallOp>( 753 loc, LLVM::LLVMTokenType::get(ctx), 754 builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); 755 756 // Call async runtime API to resume a coroutine in the managed thread when 757 // the async await argument becomes ready. 758 SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle, 759 resumePtr.res()}; 760 builder.create<CallOp>(loc, TypeRange(), coroAwaitFuncName, 761 awaitAndExecuteArgs); 762 763 // Split the entry block before the await operation. 764 addSuspensionPoint(coro, coroSave.getResult(0), op); 765 } 766 767 // Original operation was replaced by function call or suspension point. 768 rewriter.eraseOp(op); 769 770 return success(); 771 } 772 773 private: 774 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 775 StringRef blockingAwaitFuncName; 776 StringRef coroAwaitFuncName; 777 }; 778 779 // Lowering for `async.await` operation (only token operands are supported). 780 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 781 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 782 783 public: 784 explicit AwaitOpLowering( 785 MLIRContext *ctx, 786 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 787 : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {} 788 }; 789 790 // Lowering for `async.await_all` operation. 
791 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 792 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 793 794 public: 795 explicit AwaitAllOpLowering( 796 MLIRContext *ctx, 797 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 798 : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {} 799 }; 800 801 } // namespace 802 803 //===----------------------------------------------------------------------===// 804 805 namespace { 806 struct ConvertAsyncToLLVMPass 807 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 808 void runOnOperation() override; 809 }; 810 811 void ConvertAsyncToLLVMPass::runOnOperation() { 812 ModuleOp module = getOperation(); 813 SymbolTable symbolTable(module); 814 815 // Outline all `async.execute` body regions into async functions (coroutines). 816 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 817 818 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 819 // We currently do not support execute operations that have async value 820 // operands or produce async results. 821 if (!execute.operands().empty() || !execute.results().empty()) { 822 execute.emitOpError("can't outline async.execute op with async value " 823 "operands or returned async results"); 824 return WalkResult::interrupt(); 825 } 826 827 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 828 829 return WalkResult::advance(); 830 }); 831 832 // Failed to outline all async execute operations. 833 if (outlineResult.wasInterrupted()) { 834 signalPassFailure(); 835 return; 836 } 837 838 LLVM_DEBUG({ 839 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 840 << " async functions\n"; 841 }); 842 843 // Add declarations for all functions required by the coroutines lowering. 
844 addResumeFunction(module); 845 addAsyncRuntimeApiDeclarations(module); 846 addCoroutineIntrinsicsDeclarations(module); 847 addCRuntimeDeclarations(module); 848 849 MLIRContext *ctx = &getContext(); 850 851 // Convert async dialect types and operations to LLVM dialect. 852 AsyncRuntimeTypeConverter converter; 853 OwningRewritePatternList patterns; 854 855 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 856 patterns.insert<CallOpOpConversion>(ctx); 857 patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx); 858 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 859 patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions); 860 861 ConversionTarget target(*ctx); 862 target.addLegalOp<ConstantOp>(); 863 target.addLegalDialect<LLVM::LLVMDialect>(); 864 target.addIllegalDialect<AsyncDialect>(); 865 target.addDynamicallyLegalOp<FuncOp>( 866 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 867 target.addDynamicallyLegalOp<CallOp>( 868 [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); 869 870 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 871 signalPassFailure(); 872 } 873 } // namespace 874 875 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 876 return std::make_unique<ConvertAsyncToLLVMPass>(); 877 } 878