//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"

// Tag for the LLVM_DEBUG output emitted by this pass.
#define DEBUG_TYPE "convert-async-to-llvm"

using namespace mlir;
using namespace mlir::async;

// Prefix for functions outlined from `async.execute` op regions. It is the
// symbol name requested for every outlined function.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 37 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 38 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 39 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 40 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 41 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 42 static constexpr const char *kAddTokenToGroup = 43 "mlirAsyncRuntimeAddTokenToGroup"; 44 static constexpr const char *kAwaitAndExecute = 45 "mlirAsyncRuntimeAwaitTokenAndExecute"; 46 static constexpr const char *kAwaitAllAndExecute = 47 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 48 49 namespace { 50 // Async Runtime API function types. 51 struct AsyncAPI { 52 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 53 return FunctionType::get({}, {TokenType::get(ctx)}, ctx); 54 } 55 56 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 57 return FunctionType::get({}, {GroupType::get(ctx)}, ctx); 58 } 59 60 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 61 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 62 } 63 64 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 65 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 66 } 67 68 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 69 return FunctionType::get({GroupType::get(ctx)}, {}, ctx); 70 } 71 72 static FunctionType executeFunctionType(MLIRContext *ctx) { 73 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 74 auto resume = resumeFunctionType(ctx).getPointerTo(); 75 return FunctionType::get({hdl, resume}, {}, ctx); 76 } 77 78 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 79 auto i64 = IntegerType::get(64, ctx); 80 return FunctionType::get({TokenType::get(ctx), 
GroupType::get(ctx)}, {i64}, 81 ctx); 82 } 83 84 static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { 85 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 86 auto resume = resumeFunctionType(ctx).getPointerTo(); 87 return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); 88 } 89 90 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 91 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 92 auto resume = resumeFunctionType(ctx).getPointerTo(); 93 return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx); 94 } 95 96 // Auxiliary coroutine resume intrinsic wrapper. 97 static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { 98 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 99 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 100 return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); 101 } 102 }; 103 } // namespace 104 105 // Adds Async Runtime C API declarations to the module. 106 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 107 auto builder = OpBuilder::atBlockTerminator(module.getBody()); 108 109 auto addFuncDecl = [&](StringRef name, FunctionType type) { 110 if (module.lookupSymbol(name)) 111 return; 112 builder.create<FuncOp>(module.getLoc(), name, type).setPrivate(); 113 }; 114 115 MLIRContext *ctx = module.getContext(); 116 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 117 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 118 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 119 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 120 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 121 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 122 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 123 addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx)); 124 addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 125 } 126 127 
//===----------------------------------------------------------------------===// 128 // LLVM coroutines intrinsics declarations. 129 //===----------------------------------------------------------------------===// 130 131 static constexpr const char *kCoroId = "llvm.coro.id"; 132 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 133 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 134 static constexpr const char *kCoroSave = "llvm.coro.save"; 135 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 136 static constexpr const char *kCoroEnd = "llvm.coro.end"; 137 static constexpr const char *kCoroFree = "llvm.coro.free"; 138 static constexpr const char *kCoroResume = "llvm.coro.resume"; 139 140 /// Adds an LLVM function declaration to a module. 141 static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name, 142 LLVM::LLVMType ret, 143 ArrayRef<LLVM::LLVMType> params) { 144 if (module.lookupSymbol(name)) 145 return; 146 LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false); 147 builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type); 148 } 149 150 /// Adds coroutine intrinsics declarations to the module. 
151 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 152 using namespace mlir::LLVM; 153 154 MLIRContext *ctx = module.getContext(); 155 OpBuilder builder(module.getBody()->getTerminator()); 156 157 auto token = LLVMTokenType::get(ctx); 158 auto voidTy = LLVMType::getVoidTy(ctx); 159 160 auto i8 = LLVMType::getInt8Ty(ctx); 161 auto i1 = LLVMType::getInt1Ty(ctx); 162 auto i32 = LLVMType::getInt32Ty(ctx); 163 auto i64 = LLVMType::getInt64Ty(ctx); 164 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 165 166 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); 167 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); 168 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); 169 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); 170 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); 171 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); 172 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); 173 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // Add malloc/free declarations to the module. 178 //===----------------------------------------------------------------------===// 179 180 static constexpr const char *kMalloc = "malloc"; 181 static constexpr const char *kFree = "free"; 182 183 /// Adds malloc/free declarations to the module. 
184 static void addCRuntimeDeclarations(ModuleOp module) { 185 using namespace mlir::LLVM; 186 187 MLIRContext *ctx = module.getContext(); 188 OpBuilder builder(module.getBody()->getTerminator()); 189 190 auto voidTy = LLVMType::getVoidTy(ctx); 191 auto i64 = LLVMType::getInt64Ty(ctx); 192 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 193 194 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 195 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 196 } 197 198 //===----------------------------------------------------------------------===// 199 // Coroutine resume function wrapper. 200 //===----------------------------------------------------------------------===// 201 202 static constexpr const char *kResume = "__resume"; 203 204 // A function that takes a coroutine handle and calls a `llvm.coro.resume` 205 // intrinsics. We need this function to be able to pass it to the async 206 // runtime execute API. 207 static void addResumeFunction(ModuleOp module) { 208 MLIRContext *ctx = module.getContext(); 209 210 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 211 Location loc = module.getLoc(); 212 213 if (module.lookupSymbol(kResume)) 214 return; 215 216 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 217 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 218 219 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 220 loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false)); 221 resumeOp.setPrivate(); 222 223 auto *block = resumeOp.addEntryBlock(); 224 OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); 225 226 blockBuilder.create<LLVM::CallOp>(loc, Type(), 227 blockBuilder.getSymbolRefAttr(kCoroResume), 228 resumeOp.getArgument(0)); 229 230 blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange()); 231 } 232 233 //===----------------------------------------------------------------------===// 234 // async.execute op outlining to the coroutine functions. 
//===----------------------------------------------------------------------===//

// Function targeted for coroutine transformation has two additional blocks at
// the end: coroutine cleanup and coroutine suspension.
//
// async.await op lowering additionally creates a resume block for each
// operation to enable non-blocking waiting via coroutine suspension.
//
// NOTE: field order matters. `setupCoroMachinery` returns this struct via
// aggregate initialization in declaration order.
namespace {
struct CoroMachinery {
  // The `!async.token` created in the entry block and returned from the
  // suspend block.
  Value asyncToken;
  // Coroutine handle: the result of the @llvm.coro.begin call.
  Value coroHandle;
  // Block that frees the coroutine frame and then branches to `suspend`.
  Block *cleanup;
  // Block that calls @llvm.coro.end and returns `asyncToken`.
  Block *suspend;
};
} // namespace

// Builds a coroutine template compatible with LLVM coroutines lowering.
//
// - `entry` block sets up the coroutine.
// - `cleanup` block cleans up the coroutine state.
// - `suspend` block after the @llvm.coro.end() defines what value will be
//   returned to the initial caller of a coroutine. Everything before the
//   @llvm.coro.end() will be executed at every suspension point.
//
// Coroutine structure (only the important bits):
//
//   func @async_execute_fn(<function-arguments>) -> !async.token {
//     ^entryBlock(<function-arguments>):
//       %token = <async token> : !async.token // create async runtime token
//       %id = llvm.call @llvm.coro.id(...)    // create a coroutine id
//       %hdl = llvm.call @llvm.coro.begin(...) // create a coroutine handle
//       br ^cleanup
//
//     ^cleanup:
//       llvm.call @llvm.coro.free(...)        // delete coroutine state
//       br ^suspend
//
//     ^suspend:
//       llvm.call @llvm.coro.end(...)         // marks the end of a coroutine
//       return %token : !async.token
//   }
//
// The actual code for the async.execute operation body region will be inserted
// before the entry block terminator.
278 // 279 // 280 static CoroMachinery setupCoroMachinery(FuncOp func) { 281 assert(func.getBody().empty() && "Function must have empty body"); 282 283 MLIRContext *ctx = func.getContext(); 284 285 auto token = LLVM::LLVMTokenType::get(ctx); 286 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 287 auto i32 = LLVM::LLVMType::getInt32Ty(ctx); 288 auto i64 = LLVM::LLVMType::getInt64Ty(ctx); 289 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 290 291 Block *entryBlock = func.addEntryBlock(); 292 Location loc = func.getBody().getLoc(); 293 294 OpBuilder builder = OpBuilder::atBlockBegin(entryBlock); 295 296 // ------------------------------------------------------------------------ // 297 // Allocate async tokens/values that we will return from a ramp function. 298 // ------------------------------------------------------------------------ // 299 auto createToken = 300 builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx)); 301 302 // ------------------------------------------------------------------------ // 303 // Initialize coroutine: allocate frame, get coroutine handle. 304 // ------------------------------------------------------------------------ // 305 306 // Constants for initializing coroutine frame. 307 auto constZero = 308 builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0)); 309 auto constFalse = 310 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 311 auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr); 312 313 // Get coroutine id: @llvm.coro.id 314 auto coroId = builder.create<LLVM::CallOp>( 315 loc, token, builder.getSymbolRefAttr(kCoroId), 316 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 317 318 // Get coroutine frame size: @llvm.coro.size.i64 319 auto coroSize = builder.create<LLVM::CallOp>( 320 loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 321 322 // Allocate memory for coroutine frame. 
323 auto coroAlloc = builder.create<LLVM::CallOp>( 324 loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), 325 ValueRange(coroSize.getResult(0))); 326 327 // Begin a coroutine: @llvm.coro.begin 328 auto coroHdl = builder.create<LLVM::CallOp>( 329 loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 330 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 331 332 Block *cleanupBlock = func.addBlock(); 333 Block *suspendBlock = func.addBlock(); 334 335 // ------------------------------------------------------------------------ // 336 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 337 // ------------------------------------------------------------------------ // 338 builder.setInsertionPointToStart(cleanupBlock); 339 340 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 341 auto coroMem = builder.create<LLVM::CallOp>( 342 loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), 343 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 344 345 // Free the memory. 346 builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree), 347 ValueRange(coroMem.getResult(0))); 348 // Branch into the suspend block. 349 builder.create<BranchOp>(loc, suspendBlock); 350 351 // ------------------------------------------------------------------------ // 352 // Coroutine suspend block: mark the end of a coroutine and return allocated 353 // async token. 354 // ------------------------------------------------------------------------ // 355 builder.setInsertionPointToStart(suspendBlock); 356 357 // Mark the end of a coroutine: @llvm.coro.end. 358 builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd), 359 ValueRange({coroHdl.getResult(0), constFalse})); 360 361 // Return created `async.token` from the suspend block. This will be the 362 // return value of a coroutine ramp function. 363 builder.create<ReturnOp>(loc, createToken.getResult(0)); 364 365 // Branch from the entry block to the cleanup block to create a valid CFG. 
366 builder.setInsertionPointToEnd(entryBlock); 367 368 builder.create<BranchOp>(loc, cleanupBlock); 369 370 // `async.await` op lowering will create resume blocks for async 371 // continuations, and will conditionally branch to cleanup or suspend blocks. 372 373 return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock, 374 suspendBlock}; 375 } 376 377 // Adds a suspension point before the `op`, and moves `op` and all operations 378 // after it into the resume block. Returns a pointer to the resume block. 379 // 380 // `coroState` must be a value returned from the call to @llvm.coro.save(...) 381 // intrinsic (saved coroutine state). 382 // 383 // Before: 384 // 385 // ^bb0: 386 // "opBefore"(...) 387 // "op"(...) 388 // ^cleanup: ... 389 // ^suspend: ... 390 // 391 // After: 392 // 393 // ^bb0: 394 // "opBefore"(...) 395 // %suspend = llmv.call @llvm.coro.suspend(...) 396 // switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 397 // ^resume: 398 // "op"(...) 399 // ^cleanup: ... 400 // ^suspend: ... 401 // 402 static Block *addSuspensionPoint(CoroMachinery coro, Value coroState, 403 Operation *op) { 404 MLIRContext *ctx = op->getContext(); 405 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 406 auto i8 = LLVM::LLVMType::getInt8Ty(ctx); 407 408 Location loc = op->getLoc(); 409 Block *splitBlock = op->getBlock(); 410 411 // Split the block before `op`, newly added block is the resume block. 412 Block *resume = splitBlock->splitBlock(op); 413 414 // Add a coroutine suspension in place of original `op` in the split block. 
415 OpBuilder builder = OpBuilder::atBlockEnd(splitBlock); 416 417 auto constFalse = 418 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 419 420 // Suspend a coroutine: @llvm.coro.suspend 421 auto coroSuspend = builder.create<LLVM::CallOp>( 422 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 423 ValueRange({coroState, constFalse})); 424 425 // After a suspension point decide if we should branch into resume, cleanup 426 // or suspend block of the coroutine (see @llvm.coro.suspend return code 427 // documentation). 428 auto constZero = 429 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 430 auto constNegOne = 431 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 432 433 Block *resumeOrCleanup = builder.createBlock(resume); 434 435 // Suspend the coroutine ...? 436 builder.setInsertionPointToEnd(splitBlock); 437 auto isNegOne = builder.create<LLVM::ICmpOp>( 438 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 439 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 440 /*falseDest=*/resumeOrCleanup); 441 442 // ... or resume or cleanup the coroutine? 443 builder.setInsertionPointToStart(resumeOrCleanup); 444 auto isZero = builder.create<LLVM::ICmpOp>( 445 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 446 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 447 /*falseDest=*/coro.cleanup); 448 449 return resume; 450 } 451 452 // Outline the body region attached to the `async.execute` op into a standalone 453 // function. 454 static std::pair<FuncOp, CoroMachinery> 455 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 456 ModuleOp module = execute.getParentOfType<ModuleOp>(); 457 458 MLIRContext *ctx = module.getContext(); 459 Location loc = execute.getLoc(); 460 461 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 462 463 // Collect all outlined function inputs. 
464 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 465 execute.dependencies().end()); 466 getUsedValuesDefinedAbove(execute.body(), functionInputs); 467 468 // Collect types for the outlined function inputs and outputs. 469 auto typesRange = llvm::map_range( 470 functionInputs, [](Value value) { return value.getType(); }); 471 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 472 auto outputTypes = execute.getResultTypes(); 473 474 auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); 475 auto funcAttrs = ArrayRef<NamedAttribute>(); 476 477 // TODO: Derive outlined function name from the parent FuncOp (support 478 // multiple nested async.execute operations). 479 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 480 symbolTable.insert(func, moduleBuilder.getInsertionPoint()); 481 482 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 483 484 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 485 // blocks, adding llvm.coro instrinsics and setting up control flow. 486 CoroMachinery coro = setupCoroMachinery(func); 487 488 // Suspend async function at the end of an entry block, and resume it using 489 // Async execute API (execution will be resumed in a thread managed by the 490 // async runtime). 491 Block *entryBlock = &func.getBlocks().front(); 492 OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock); 493 494 // A pointer to coroutine resume intrinsic wrapper. 495 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 496 auto resumePtr = builder.create<LLVM::AddressOfOp>( 497 loc, resumeFnTy.getPointerTo(), kResume); 498 499 // Save the coroutine state: @llvm.coro.save 500 auto coroSave = builder.create<LLVM::CallOp>( 501 loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 502 ValueRange({coro.coroHandle})); 503 504 // Call async runtime API to execute a coroutine in the managed thread. 
505 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 506 builder.create<CallOp>(loc, Type(), kExecute, executeArgs); 507 508 // Split the entry block before the terminator. 509 Block *resume = addSuspensionPoint(coro, coroSave.getResult(0), 510 entryBlock->getTerminator()); 511 512 // Await on all dependencies before starting to execute the body region. 513 builder.setInsertionPointToStart(resume); 514 for (size_t i = 0; i < execute.dependencies().size(); ++i) 515 builder.create<AwaitOp>(loc, func.getArgument(i)); 516 517 // Map from function inputs defined above the execute op to the function 518 // arguments. 519 BlockAndValueMapping valueMapping; 520 valueMapping.map(functionInputs, func.getArguments()); 521 522 // Clone all operations from the execute operation body into the outlined 523 // function body, and replace all `async.yield` operations with a call 524 // to async runtime to emplace the result token. 525 for (Operation &op : execute.body().getOps()) { 526 if (isa<async::YieldOp>(op)) { 527 builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken); 528 continue; 529 } 530 builder.clone(op, valueMapping); 531 } 532 533 // Replace the original `async.execute` with a call to outlined function. 534 OpBuilder callBuilder(execute); 535 auto callOutlinedFunc = 536 callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(), 537 functionInputs.getArrayRef()); 538 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 539 execute.erase(); 540 541 return {func, coro}; 542 } 543 544 //===----------------------------------------------------------------------===// 545 // Convert Async dialect types to LLVM types. 
546 //===----------------------------------------------------------------------===// 547 548 namespace { 549 class AsyncRuntimeTypeConverter : public TypeConverter { 550 public: 551 AsyncRuntimeTypeConverter() { addConversion(convertType); } 552 553 static Type convertType(Type type) { 554 MLIRContext *ctx = type.getContext(); 555 // Convert async tokens and groups to opaque pointers. 556 if (type.isa<TokenType, GroupType>()) 557 return LLVM::LLVMType::getInt8PtrTy(ctx); 558 return type; 559 } 560 }; 561 } // namespace 562 563 //===----------------------------------------------------------------------===// 564 // Convert types for all call operations to lowered async types. 565 //===----------------------------------------------------------------------===// 566 567 namespace { 568 class CallOpOpConversion : public ConversionPattern { 569 public: 570 explicit CallOpOpConversion(MLIRContext *ctx) 571 : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} 572 573 LogicalResult 574 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 575 ConversionPatternRewriter &rewriter) const override { 576 AsyncRuntimeTypeConverter converter; 577 578 SmallVector<Type, 5> resultTypes; 579 converter.convertTypes(op->getResultTypes(), resultTypes); 580 581 CallOp call = cast<CallOp>(op); 582 rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(), 583 call.getOperands()); 584 585 return success(); 586 } 587 }; 588 } // namespace 589 590 //===----------------------------------------------------------------------===// 591 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 
592 //===----------------------------------------------------------------------===// 593 594 namespace { 595 class CreateGroupOpLowering : public ConversionPattern { 596 public: 597 explicit CreateGroupOpLowering(MLIRContext *ctx) 598 : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {} 599 600 LogicalResult 601 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 602 ConversionPatternRewriter &rewriter) const override { 603 auto retTy = GroupType::get(op->getContext()); 604 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy); 605 return success(); 606 } 607 }; 608 } // namespace 609 610 //===----------------------------------------------------------------------===// 611 // async.add_to_group op lowering to runtime function call. 612 //===----------------------------------------------------------------------===// 613 614 namespace { 615 class AddToGroupOpLowering : public ConversionPattern { 616 public: 617 explicit AddToGroupOpLowering(MLIRContext *ctx) 618 : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {} 619 620 LogicalResult 621 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 622 ConversionPatternRewriter &rewriter) const override { 623 // Currently we can only add tokens to the group. 624 auto addToGroup = cast<AddToGroupOp>(op); 625 if (!addToGroup.operand().getType().isa<TokenType>()) 626 return failure(); 627 628 auto i64 = IntegerType::get(64, op->getContext()); 629 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); 630 return success(); 631 } 632 }; 633 } // namespace 634 635 //===----------------------------------------------------------------------===// 636 // async.await and async.await_all op lowerings to the corresponding async 637 // runtime function calls. 
638 //===----------------------------------------------------------------------===// 639 640 namespace { 641 642 template <typename AwaitType, typename AwaitableType> 643 class AwaitOpLoweringBase : public ConversionPattern { 644 protected: 645 explicit AwaitOpLoweringBase( 646 MLIRContext *ctx, 647 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions, 648 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) 649 : ConversionPattern(AwaitType::getOperationName(), 1, ctx), 650 outlinedFunctions(outlinedFunctions), 651 blockingAwaitFuncName(blockingAwaitFuncName), 652 coroAwaitFuncName(coroAwaitFuncName) {} 653 654 public: 655 LogicalResult 656 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 657 ConversionPatternRewriter &rewriter) const override { 658 // We can only await on one the `AwaitableType` (for `await` it can be 659 // only a `token`, for `await_all` it is a `group`). 660 auto await = cast<AwaitType>(op); 661 if (!await.operand().getType().template isa<AwaitableType>()) 662 return failure(); 663 664 // Check if await operation is inside the outlined coroutine function. 665 auto func = await.template getParentOfType<FuncOp>(); 666 auto outlined = outlinedFunctions.find(func); 667 const bool isInCoroutine = outlined != outlinedFunctions.end(); 668 669 Location loc = op->getLoc(); 670 671 // Inside regular function we convert await operation to the blocking 672 // async API await function call. 673 if (!isInCoroutine) 674 rewriter.create<CallOp>(loc, Type(), blockingAwaitFuncName, 675 ValueRange(op->getOperand(0))); 676 677 // Inside the coroutine we convert await operation into coroutine suspension 678 // point, and resume execution asynchronously. 679 if (isInCoroutine) { 680 const CoroMachinery &coro = outlined->getSecond(); 681 682 OpBuilder builder(op); 683 MLIRContext *ctx = op->getContext(); 684 685 // A pointer to coroutine resume intrinsic wrapper. 
686 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 687 auto resumePtr = builder.create<LLVM::AddressOfOp>( 688 loc, resumeFnTy.getPointerTo(), kResume); 689 690 // Save the coroutine state: @llvm.coro.save 691 auto coroSave = builder.create<LLVM::CallOp>( 692 loc, LLVM::LLVMTokenType::get(ctx), 693 builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); 694 695 // Call async runtime API to resume a coroutine in the managed thread when 696 // the async await argument becomes ready. 697 SmallVector<Value, 3> awaitAndExecuteArgs = { 698 await.getOperand(), coro.coroHandle, resumePtr.res()}; 699 builder.create<CallOp>(loc, Type(), coroAwaitFuncName, 700 awaitAndExecuteArgs); 701 702 // Split the entry block before the await operation. 703 addSuspensionPoint(coro, coroSave.getResult(0), op); 704 } 705 706 // Original operation was replaced by function call or suspension point. 707 rewriter.eraseOp(op); 708 709 return success(); 710 } 711 712 private: 713 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 714 StringRef blockingAwaitFuncName; 715 StringRef coroAwaitFuncName; 716 }; 717 718 // Lowering for `async.await` operation (only token operands are supported). 719 class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 720 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 721 722 public: 723 explicit AwaitOpLowering( 724 MLIRContext *ctx, 725 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 726 : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {} 727 }; 728 729 // Lowering for `async.await_all` operation. 
730 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 731 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 732 733 public: 734 explicit AwaitAllOpLowering( 735 MLIRContext *ctx, 736 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 737 : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {} 738 }; 739 740 } // namespace 741 742 //===----------------------------------------------------------------------===// 743 744 namespace { 745 struct ConvertAsyncToLLVMPass 746 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 747 void runOnOperation() override; 748 }; 749 750 void ConvertAsyncToLLVMPass::runOnOperation() { 751 ModuleOp module = getOperation(); 752 SymbolTable symbolTable(module); 753 754 // Outline all `async.execute` body regions into async functions (coroutines). 755 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 756 757 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 758 // We currently do not support execute operations that have async value 759 // operands or produce async results. 760 if (!execute.operands().empty() || !execute.results().empty()) { 761 execute.emitOpError("can't outline async.execute op with async value " 762 "operands or returned async results"); 763 return WalkResult::interrupt(); 764 } 765 766 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 767 768 return WalkResult::advance(); 769 }); 770 771 // Failed to outline all async execute operations. 772 if (outlineResult.wasInterrupted()) { 773 signalPassFailure(); 774 return; 775 } 776 777 LLVM_DEBUG({ 778 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 779 << " async functions\n"; 780 }); 781 782 // Add declarations for all functions required by the coroutines lowering. 
783 addResumeFunction(module); 784 addAsyncRuntimeApiDeclarations(module); 785 addCoroutineIntrinsicsDeclarations(module); 786 addCRuntimeDeclarations(module); 787 788 MLIRContext *ctx = &getContext(); 789 790 // Convert async dialect types and operations to LLVM dialect. 791 AsyncRuntimeTypeConverter converter; 792 OwningRewritePatternList patterns; 793 794 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 795 patterns.insert<CallOpOpConversion>(ctx); 796 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 797 patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions); 798 799 ConversionTarget target(*ctx); 800 target.addLegalDialect<LLVM::LLVMDialect>(); 801 target.addIllegalDialect<AsyncDialect>(); 802 target.addDynamicallyLegalOp<FuncOp>( 803 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 804 target.addDynamicallyLegalOp<CallOp>( 805 [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); 806 807 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 808 signalPassFailure(); 809 } 810 } // namespace 811 812 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 813 return std::make_unique<ConvertAsyncToLLVMPass>(); 814 } 815