1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Async/IR/Async.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "mlir/Transforms/RegionUtils.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/Support/FormatVariadic.h" 23 24 #define DEBUG_TYPE "convert-async-to-llvm" 25 26 using namespace mlir; 27 using namespace mlir::async; 28 29 // Prefix for functions outlined from `async.execute` op regions. 30 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 31 32 //===----------------------------------------------------------------------===// 33 // Async Runtime C API declaration. 34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 37 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 38 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 39 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 40 static constexpr const char *kAwaitAndExecute = 41 "mlirAsyncRuntimeAwaitTokenAndExecute"; 42 43 namespace { 44 // Async Runtime API function types. 45 struct AsyncAPI { 46 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 47 return FunctionType::get({}, {TokenType::get(ctx)}, ctx); 48 } 49 50 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 51 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 52 } 53 54 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 55 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 56 } 57 58 static FunctionType executeFunctionType(MLIRContext *ctx) { 59 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 60 auto resume = resumeFunctionType(ctx).getPointerTo(); 61 return FunctionType::get({hdl, resume}, {}, ctx); 62 } 63 64 static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { 65 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 66 auto resume = resumeFunctionType(ctx).getPointerTo(); 67 return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); 68 } 69 70 // Auxiliary coroutine resume intrinsic wrapper. 71 static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { 72 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 73 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 74 return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); 75 } 76 }; 77 } // namespace 78 79 // Adds Async Runtime C API declarations to the module. 80 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 81 auto builder = OpBuilder::atBlockTerminator(module.getBody()); 82 83 MLIRContext *ctx = module.getContext(); 84 Location loc = module.getLoc(); 85 86 if (!module.lookupSymbol(kCreateToken)) 87 builder.create<FuncOp>(loc, kCreateToken, 88 AsyncAPI::createTokenFunctionType(ctx)); 89 90 if (!module.lookupSymbol(kEmplaceToken)) 91 builder.create<FuncOp>(loc, kEmplaceToken, 92 AsyncAPI::emplaceTokenFunctionType(ctx)); 93 94 if (!module.lookupSymbol(kAwaitToken)) 95 builder.create<FuncOp>(loc, kAwaitToken, 96 AsyncAPI::awaitTokenFunctionType(ctx)); 97 98 if (!module.lookupSymbol(kExecute)) 99 builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx)); 100 101 if (!module.lookupSymbol(kAwaitAndExecute)) 102 builder.create<FuncOp>(loc, kAwaitAndExecute, 103 AsyncAPI::awaitAndExecuteFunctionType(ctx)); 104 } 105 106 //===----------------------------------------------------------------------===// 107 // LLVM coroutines intrinsics declarations. 108 //===----------------------------------------------------------------------===// 109 110 static constexpr const char *kCoroId = "llvm.coro.id"; 111 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 112 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 113 static constexpr const char *kCoroSave = "llvm.coro.save"; 114 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 115 static constexpr const char *kCoroEnd = "llvm.coro.end"; 116 static constexpr const char *kCoroFree = "llvm.coro.free"; 117 static constexpr const char *kCoroResume = "llvm.coro.resume"; 118 119 /// Adds coroutine intrinsics declarations to the module. 120 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 121 using namespace mlir::LLVM; 122 123 MLIRContext *ctx = module.getContext(); 124 Location loc = module.getLoc(); 125 126 OpBuilder builder(module.getBody()->getTerminator()); 127 128 auto token = LLVMTokenType::get(ctx); 129 auto voidTy = LLVMType::getVoidTy(ctx); 130 131 auto i8 = LLVMType::getInt8Ty(ctx); 132 auto i1 = LLVMType::getInt1Ty(ctx); 133 auto i32 = LLVMType::getInt32Ty(ctx); 134 auto i64 = LLVMType::getInt64Ty(ctx); 135 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 136 137 if (!module.lookupSymbol(kCoroId)) 138 builder.create<LLVMFuncOp>( 139 loc, kCoroId, 140 LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false)); 141 142 if (!module.lookupSymbol(kCoroSizeI64)) 143 builder.create<LLVMFuncOp>(loc, kCoroSizeI64, 144 LLVMType::getFunctionTy(i64, false)); 145 146 if (!module.lookupSymbol(kCoroBegin)) 147 builder.create<LLVMFuncOp>( 148 loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); 149 150 if (!module.lookupSymbol(kCoroSave)) 151 builder.create<LLVMFuncOp>(loc, kCoroSave, 152 LLVMType::getFunctionTy(token, i8Ptr, false)); 153 154 if (!module.lookupSymbol(kCoroSuspend)) 155 builder.create<LLVMFuncOp>(loc, kCoroSuspend, 156 LLVMType::getFunctionTy(i8, {token, i1}, false)); 157 158 if (!module.lookupSymbol(kCoroEnd)) 159 builder.create<LLVMFuncOp>(loc, kCoroEnd, 160 LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false)); 161 162 if (!module.lookupSymbol(kCoroFree)) 163 builder.create<LLVMFuncOp>( 164 loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); 165 166 if (!module.lookupSymbol(kCoroResume)) 167 builder.create<LLVMFuncOp>(loc, kCoroResume, 168 LLVMType::getFunctionTy(voidTy, i8Ptr, false)); 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Add malloc/free declarations to the module. 173 //===----------------------------------------------------------------------===// 174 175 static constexpr const char *kMalloc = "malloc"; 176 static constexpr const char *kFree = "free"; 177 178 /// Adds malloc/free declarations to the module. 179 static void addCRuntimeDeclarations(ModuleOp module) { 180 using namespace mlir::LLVM; 181 182 MLIRContext *ctx = module.getContext(); 183 Location loc = module.getLoc(); 184 185 OpBuilder builder(module.getBody()->getTerminator()); 186 187 auto voidTy = LLVMType::getVoidTy(ctx); 188 auto i64 = LLVMType::getInt64Ty(ctx); 189 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 190 191 if (!module.lookupSymbol(kMalloc)) 192 builder.create<LLVM::LLVMFuncOp>( 193 loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false)); 194 195 if (!module.lookupSymbol(kFree)) 196 builder.create<LLVM::LLVMFuncOp>( 197 loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false)); 198 } 199 200 //===----------------------------------------------------------------------===// 201 // Coroutine resume function wrapper. 202 //===----------------------------------------------------------------------===// 203 204 static constexpr const char *kResume = "__resume"; 205 206 // A function that takes a coroutine handle and calls a `llvm.coro.resume` 207 // intrinsics. We need this function to be able to pass it to the async 208 // runtime execute API. 209 static void addResumeFunction(ModuleOp module) { 210 MLIRContext *ctx = module.getContext(); 211 212 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 213 Location loc = module.getLoc(); 214 215 if (module.lookupSymbol(kResume)) 216 return; 217 218 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 219 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 220 221 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 222 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 223 SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private); 224 225 auto *block = resumeOp.addEntryBlock(); 226 OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); 227 228 blockBuilder.create<LLVM::CallOp>(loc, Type(), 229 blockBuilder.getSymbolRefAttr(kCoroResume), 230 resumeOp.getArgument(0)); 231 232 blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange()); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // async.execute op outlining to the coroutine functions. 237 //===----------------------------------------------------------------------===// 238 239 // Function targeted for coroutine transformation has two additional blocks at 240 // the end: coroutine cleanup and coroutine suspension. 241 // 242 // async.await op lowering additionaly creates a resume block for each 243 // operation to enable non-blocking waiting via coroutine suspension. 244 namespace { 245 struct CoroMachinery { 246 Value asyncToken; 247 Value coroHandle; 248 Block *cleanup; 249 Block *suspend; 250 }; 251 } // namespace 252 253 // Builds an coroutine template compatible with LLVM coroutines lowering. 254 // 255 // - `entry` block sets up the coroutine. 256 // - `cleanup` block cleans up the coroutine state. 257 // - `suspend block after the @llvm.coro.end() defines what value will be 258 // returned to the initial caller of a coroutine. Everything before the 259 // @llvm.coro.end() will be executed at every suspension point. 260 // 261 // Coroutine structure (only the important bits): 262 // 263 // func @async_execute_fn(<function-arguments>) -> !async.token { 264 // ^entryBlock(<function-arguments>): 265 // %token = <async token> : !async.token // create async runtime token 266 // %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle 267 // br ^cleanup 268 // 269 // ^cleanup: 270 // llvm.call @llvm.coro.free(...) // delete coroutine state 271 // br ^suspend 272 // 273 // ^suspend: 274 // llvm.call @llvm.coro.end(...) // marks the end of a coroutine 275 // return %token : !async.token 276 // } 277 // 278 // The actual code for the async.execute operation body region will be inserted 279 // before the entry block terminator. 280 // 281 // 282 static CoroMachinery setupCoroMachinery(FuncOp func) { 283 assert(func.getBody().empty() && "Function must have empty body"); 284 285 MLIRContext *ctx = func.getContext(); 286 287 auto token = LLVM::LLVMTokenType::get(ctx); 288 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 289 auto i32 = LLVM::LLVMType::getInt32Ty(ctx); 290 auto i64 = LLVM::LLVMType::getInt64Ty(ctx); 291 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 292 293 Block *entryBlock = func.addEntryBlock(); 294 Location loc = func.getBody().getLoc(); 295 296 OpBuilder builder = OpBuilder::atBlockBegin(entryBlock); 297 298 // ------------------------------------------------------------------------ // 299 // Allocate async tokens/values that we will return from a ramp function. 300 // ------------------------------------------------------------------------ // 301 auto createToken = 302 builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx)); 303 304 // ------------------------------------------------------------------------ // 305 // Initialize coroutine: allocate frame, get coroutine handle. 306 // ------------------------------------------------------------------------ // 307 308 // Constants for initializing coroutine frame. 309 auto constZero = 310 builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0)); 311 auto constFalse = 312 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 313 auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr); 314 315 // Get coroutine id: @llvm.coro.id 316 auto coroId = builder.create<LLVM::CallOp>( 317 loc, token, builder.getSymbolRefAttr(kCoroId), 318 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 319 320 // Get coroutine frame size: @llvm.coro.size.i64 321 auto coroSize = builder.create<LLVM::CallOp>( 322 loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 323 324 // Allocate memory for coroutine frame. 325 auto coroAlloc = builder.create<LLVM::CallOp>( 326 loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), 327 ValueRange(coroSize.getResult(0))); 328 329 // Begin a coroutine: @llvm.coro.begin 330 auto coroHdl = builder.create<LLVM::CallOp>( 331 loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 332 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 333 334 Block *cleanupBlock = func.addBlock(); 335 Block *suspendBlock = func.addBlock(); 336 337 // ------------------------------------------------------------------------ // 338 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 339 // ------------------------------------------------------------------------ // 340 builder.setInsertionPointToStart(cleanupBlock); 341 342 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 343 auto coroMem = builder.create<LLVM::CallOp>( 344 loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), 345 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 346 347 // Free the memory. 348 builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree), 349 ValueRange(coroMem.getResult(0))); 350 // Branch into the suspend block. 351 builder.create<BranchOp>(loc, suspendBlock); 352 353 // ------------------------------------------------------------------------ // 354 // Coroutine suspend block: mark the end of a coroutine and return allocated 355 // async token. 356 // ------------------------------------------------------------------------ // 357 builder.setInsertionPointToStart(suspendBlock); 358 359 // Mark the end of a coroutine: @llvm.coro.end. 360 builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd), 361 ValueRange({coroHdl.getResult(0), constFalse})); 362 363 // Return created `async.token` from the suspend block. This will be the 364 // return value of a coroutine ramp function. 365 builder.create<ReturnOp>(loc, createToken.getResult(0)); 366 367 // Branch from the entry block to the cleanup block to create a valid CFG. 368 builder.setInsertionPointToEnd(entryBlock); 369 370 builder.create<BranchOp>(loc, cleanupBlock); 371 372 // `async.await` op lowering will create resume blocks for async 373 // continuations, and will conditionally branch to cleanup or suspend blocks. 374 375 return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock, 376 suspendBlock}; 377 } 378 379 // Adds a suspension point before the `op`, and moves `op` and all operations 380 // after it into the resume block. Returns a pointer to the resume block. 381 // 382 // `coroState` must be a value returned from the call to @llvm.coro.save(...) 383 // intrinsic (saved coroutine state). 384 // 385 // Before: 386 // 387 // ^bb0: 388 // "opBefore"(...) 389 // "op"(...) 390 // ^cleanup: ... 391 // ^suspend: ... 392 // 393 // After: 394 // 395 // ^bb0: 396 // "opBefore"(...) 397 // %suspend = llmv.call @llvm.coro.suspend(...) 398 // switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 399 // ^resume: 400 // "op"(...) 401 // ^cleanup: ... 402 // ^suspend: ... 403 // 404 static Block *addSuspensionPoint(CoroMachinery coro, Value coroState, 405 Operation *op) { 406 MLIRContext *ctx = op->getContext(); 407 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 408 auto i8 = LLVM::LLVMType::getInt8Ty(ctx); 409 410 Location loc = op->getLoc(); 411 Block *splitBlock = op->getBlock(); 412 413 // Split the block before `op`, newly added block is the resume block. 414 Block *resume = splitBlock->splitBlock(op); 415 416 // Add a coroutine suspension in place of original `op` in the split block. 417 OpBuilder builder = OpBuilder::atBlockEnd(splitBlock); 418 419 auto constFalse = 420 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 421 422 // Suspend a coroutine: @llvm.coro.suspend 423 auto coroSuspend = builder.create<LLVM::CallOp>( 424 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 425 ValueRange({coroState, constFalse})); 426 427 // After a suspension point decide if we should branch into resume, cleanup 428 // or suspend block of the coroutine (see @llvm.coro.suspend return code 429 // documentation). 430 auto constZero = 431 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 432 auto constNegOne = 433 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 434 435 Block *resumeOrCleanup = builder.createBlock(resume); 436 437 // Suspend the coroutine ...? 438 builder.setInsertionPointToEnd(splitBlock); 439 auto isNegOne = builder.create<LLVM::ICmpOp>( 440 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 441 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 442 /*falseDest=*/resumeOrCleanup); 443 444 // ... or resume or cleanup the coroutine? 445 builder.setInsertionPointToStart(resumeOrCleanup); 446 auto isZero = builder.create<LLVM::ICmpOp>( 447 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 448 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 449 /*falseDest=*/coro.cleanup); 450 451 return resume; 452 } 453 454 // Outline the body region attached to the `async.execute` op into a standalone 455 // function. 456 static std::pair<FuncOp, CoroMachinery> 457 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 458 ModuleOp module = execute.getParentOfType<ModuleOp>(); 459 460 MLIRContext *ctx = module.getContext(); 461 Location loc = execute.getLoc(); 462 463 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 464 465 // Collect all outlined function inputs. 466 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 467 execute.dependencies().end()); 468 getUsedValuesDefinedAbove(execute.body(), functionInputs); 469 470 // Collect types for the outlined function inputs and outputs. 471 auto typesRange = llvm::map_range( 472 functionInputs, [](Value value) { return value.getType(); }); 473 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 474 auto outputTypes = execute.getResultTypes(); 475 476 auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); 477 auto funcAttrs = ArrayRef<NamedAttribute>(); 478 479 // TODO: Derive outlined function name from the parent FuncOp (support 480 // multiple nested async.execute operations). 481 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 482 symbolTable.insert(func, moduleBuilder.getInsertionPoint()); 483 484 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 485 486 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 487 // blocks, adding llvm.coro instrinsics and setting up control flow. 488 CoroMachinery coro = setupCoroMachinery(func); 489 490 // Suspend async function at the end of an entry block, and resume it using 491 // Async execute API (execution will be resumed in a thread managed by the 492 // async runtime). 493 Block *entryBlock = &func.getBlocks().front(); 494 OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock); 495 496 // A pointer to coroutine resume intrinsic wrapper. 497 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 498 auto resumePtr = builder.create<LLVM::AddressOfOp>( 499 loc, resumeFnTy.getPointerTo(), kResume); 500 501 // Save the coroutine state: @llvm.coro.save 502 auto coroSave = builder.create<LLVM::CallOp>( 503 loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 504 ValueRange({coro.coroHandle})); 505 506 // Call async runtime API to execute a coroutine in the managed thread. 507 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 508 builder.create<CallOp>(loc, Type(), kExecute, executeArgs); 509 510 // Split the entry block before the terminator. 511 Block *resume = addSuspensionPoint(coro, coroSave.getResult(0), 512 entryBlock->getTerminator()); 513 514 // Await on all dependencies before starting to execute the body region. 515 builder.setInsertionPointToStart(resume); 516 for (size_t i = 0; i < execute.dependencies().size(); ++i) 517 builder.create<AwaitOp>(loc, func.getArgument(i)); 518 519 // Map from function inputs defined above the execute op to the function 520 // arguments. 521 BlockAndValueMapping valueMapping; 522 valueMapping.map(functionInputs, func.getArguments()); 523 524 // Clone all operations from the execute operation body into the outlined 525 // function body, and replace all `async.yield` operations with a call 526 // to async runtime to emplace the result token. 527 for (Operation &op : execute.body().getOps()) { 528 if (isa<async::YieldOp>(op)) { 529 builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken); 530 continue; 531 } 532 builder.clone(op, valueMapping); 533 } 534 535 // Replace the original `async.execute` with a call to outlined function. 536 OpBuilder callBuilder(execute); 537 auto callOutlinedFunc = 538 callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(), 539 functionInputs.getArrayRef()); 540 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 541 execute.erase(); 542 543 return {func, coro}; 544 } 545 546 //===----------------------------------------------------------------------===// 547 // Convert Async dialect types to LLVM types. 548 //===----------------------------------------------------------------------===// 549 550 namespace { 551 class AsyncRuntimeTypeConverter : public TypeConverter { 552 public: 553 AsyncRuntimeTypeConverter() { addConversion(convertType); } 554 555 static Type convertType(Type type) { 556 MLIRContext *ctx = type.getContext(); 557 // Convert async tokens to opaque pointers. 558 if (type.isa<TokenType>()) 559 return LLVM::LLVMType::getInt8PtrTy(ctx); 560 return type; 561 } 562 }; 563 } // namespace 564 565 //===----------------------------------------------------------------------===// 566 // Convert types for all call operations to lowered async types. 567 //===----------------------------------------------------------------------===// 568 569 namespace { 570 class CallOpOpConversion : public ConversionPattern { 571 public: 572 explicit CallOpOpConversion(MLIRContext *ctx) 573 : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} 574 575 LogicalResult 576 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 577 ConversionPatternRewriter &rewriter) const override { 578 AsyncRuntimeTypeConverter converter; 579 580 SmallVector<Type, 5> resultTypes; 581 converter.convertTypes(op->getResultTypes(), resultTypes); 582 583 CallOp call = cast<CallOp>(op); 584 rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(), 585 call.getOperands()); 586 587 return success(); 588 } 589 }; 590 } // namespace 591 592 //===----------------------------------------------------------------------===// 593 // async.await op lowering to mlirAsyncRuntimeAwaitToken function call. 594 //===----------------------------------------------------------------------===// 595 596 namespace { 597 class AwaitOpLowering : public ConversionPattern { 598 public: 599 explicit AwaitOpLowering( 600 MLIRContext *ctx, 601 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 602 : ConversionPattern(AwaitOp::getOperationName(), 1, ctx), 603 outlinedFunctions(outlinedFunctions) {} 604 605 LogicalResult 606 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 607 ConversionPatternRewriter &rewriter) const override { 608 // We can only await on the token operand. Async valus are not supported. 609 auto await = cast<AwaitOp>(op); 610 if (!await.operand().getType().isa<TokenType>()) 611 return failure(); 612 613 // Check if `async.await` is inside the outlined coroutine function. 614 auto func = await.getParentOfType<FuncOp>(); 615 auto outlined = outlinedFunctions.find(func); 616 const bool isInCoroutine = outlined != outlinedFunctions.end(); 617 618 Location loc = op->getLoc(); 619 620 // Inside regular function we convert await operation to the blocking 621 // async API await function call. 622 if (!isInCoroutine) 623 rewriter.create<CallOp>(loc, Type(), kAwaitToken, 624 ValueRange(op->getOperand(0))); 625 626 // Inside the coroutine we convert await operation into coroutine suspension 627 // point, and resume execution asynchronously. 628 if (isInCoroutine) { 629 const CoroMachinery &coro = outlined->getSecond(); 630 631 OpBuilder builder(op); 632 MLIRContext *ctx = op->getContext(); 633 634 // A pointer to coroutine resume intrinsic wrapper. 635 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 636 auto resumePtr = builder.create<LLVM::AddressOfOp>( 637 loc, resumeFnTy.getPointerTo(), kResume); 638 639 // Save the coroutine state: @llvm.coro.save 640 auto coroSave = builder.create<LLVM::CallOp>( 641 loc, LLVM::LLVMTokenType::get(ctx), 642 builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); 643 644 // Call async runtime API to resume a coroutine in the managed thread when 645 // the async await argument becomes ready. 646 SmallVector<Value, 3> awaitAndExecuteArgs = { 647 await.getOperand(), coro.coroHandle, resumePtr.res()}; 648 builder.create<CallOp>(loc, Type(), kAwaitAndExecute, 649 awaitAndExecuteArgs); 650 651 // Split the entry block before the await operation. 652 addSuspensionPoint(coro, coroSave.getResult(0), op); 653 } 654 655 // Original operation was replaced by function call or suspension point. 656 rewriter.eraseOp(op); 657 658 return success(); 659 } 660 661 private: 662 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 663 }; 664 } // namespace 665 666 //===----------------------------------------------------------------------===// 667 668 namespace { 669 struct ConvertAsyncToLLVMPass 670 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 671 void runOnOperation() override; 672 }; 673 674 void ConvertAsyncToLLVMPass::runOnOperation() { 675 ModuleOp module = getOperation(); 676 SymbolTable symbolTable(module); 677 678 // Outline all `async.execute` body regions into async functions (coroutines). 679 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 680 681 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 682 // We currently do not support execute operations that have async value 683 // operands or produce async results. 684 if (!execute.operands().empty() || !execute.results().empty()) { 685 execute.emitOpError("can't outline async.execute op with async value " 686 "operands or returned async results"); 687 return WalkResult::interrupt(); 688 } 689 690 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 691 692 return WalkResult::advance(); 693 }); 694 695 // Failed to outline all async execute operations. 696 if (outlineResult.wasInterrupted()) { 697 signalPassFailure(); 698 return; 699 } 700 701 LLVM_DEBUG({ 702 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 703 << " async functions\n"; 704 }); 705 706 // Add declarations for all functions required by the coroutines lowering. 707 addResumeFunction(module); 708 addAsyncRuntimeApiDeclarations(module); 709 addCoroutineIntrinsicsDeclarations(module); 710 addCRuntimeDeclarations(module); 711 712 MLIRContext *ctx = &getContext(); 713 714 // Convert async dialect types and operations to LLVM dialect. 715 AsyncRuntimeTypeConverter converter; 716 OwningRewritePatternList patterns; 717 718 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 719 patterns.insert<CallOpOpConversion>(ctx); 720 patterns.insert<AwaitOpLowering>(ctx, outlinedFunctions); 721 722 ConversionTarget target(*ctx); 723 target.addLegalDialect<LLVM::LLVMDialect>(); 724 target.addIllegalDialect<AsyncDialect>(); 725 target.addDynamicallyLegalOp<FuncOp>( 726 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 727 target.addDynamicallyLegalOp<CallOp>( 728 [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); 729 730 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 731 signalPassFailure(); 732 } 733 } // namespace 734 735 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 736 return std::make_unique<ConvertAsyncToLLVMPass>(); 737 } 738