1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Async/IR/Async.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "mlir/Transforms/RegionUtils.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/Support/FormatVariadic.h" 23 24 #define DEBUG_TYPE "convert-async-to-llvm" 25 26 using namespace mlir; 27 using namespace mlir::async; 28 29 // Prefix for functions outlined from `async.execute` op regions. 30 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 31 32 //===----------------------------------------------------------------------===// 33 // Async Runtime C API declaration. 34 //===----------------------------------------------------------------------===// 35 36 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 37 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 38 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 39 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 40 static constexpr const char *kAwaitAndExecute = 41 "mlirAsyncRuntimeAwaitTokenAndExecute"; 42 43 namespace { 44 // Async Runtime API function types. 45 struct AsyncAPI { 46 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 47 return FunctionType::get({}, {TokenType::get(ctx)}, ctx); 48 } 49 50 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 51 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 52 } 53 54 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 55 return FunctionType::get({TokenType::get(ctx)}, {}, ctx); 56 } 57 58 static FunctionType executeFunctionType(MLIRContext *ctx) { 59 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 60 auto resume = resumeFunctionType(ctx).getPointerTo(); 61 return FunctionType::get({hdl, resume}, {}, ctx); 62 } 63 64 static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { 65 auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); 66 auto resume = resumeFunctionType(ctx).getPointerTo(); 67 return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); 68 } 69 70 // Auxiliary coroutine resume intrinsic wrapper. 71 static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { 72 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 73 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 74 return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); 75 } 76 }; 77 } // namespace 78 79 // Adds Async Runtime C API declarations to the module. 80 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 81 auto builder = OpBuilder::atBlockTerminator(module.getBody()); 82 83 MLIRContext *ctx = module.getContext(); 84 Location loc = module.getLoc(); 85 86 if (!module.lookupSymbol(kCreateToken)) 87 builder.create<FuncOp>(loc, kCreateToken, 88 AsyncAPI::createTokenFunctionType(ctx)); 89 90 if (!module.lookupSymbol(kEmplaceToken)) 91 builder.create<FuncOp>(loc, kEmplaceToken, 92 AsyncAPI::emplaceTokenFunctionType(ctx)); 93 94 if (!module.lookupSymbol(kAwaitToken)) 95 builder.create<FuncOp>(loc, kAwaitToken, 96 AsyncAPI::awaitTokenFunctionType(ctx)); 97 98 if (!module.lookupSymbol(kExecute)) 99 builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx)); 100 101 if (!module.lookupSymbol(kAwaitAndExecute)) 102 builder.create<FuncOp>(loc, kAwaitAndExecute, 103 AsyncAPI::awaitAndExecuteFunctionType(ctx)); 104 } 105 106 //===----------------------------------------------------------------------===// 107 // LLVM coroutines intrinsics declarations. 108 //===----------------------------------------------------------------------===// 109 110 static constexpr const char *kCoroId = "llvm.coro.id"; 111 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 112 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 113 static constexpr const char *kCoroSave = "llvm.coro.save"; 114 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 115 static constexpr const char *kCoroEnd = "llvm.coro.end"; 116 static constexpr const char *kCoroFree = "llvm.coro.free"; 117 static constexpr const char *kCoroResume = "llvm.coro.resume"; 118 119 /// Adds coroutine intrinsics declarations to the module. 120 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 121 using namespace mlir::LLVM; 122 123 MLIRContext *ctx = module.getContext(); 124 Location loc = module.getLoc(); 125 126 OpBuilder builder(module.getBody()->getTerminator()); 127 128 auto token = LLVMTokenType::get(ctx); 129 auto voidTy = LLVMType::getVoidTy(ctx); 130 131 auto i8 = LLVMType::getInt8Ty(ctx); 132 auto i1 = LLVMType::getInt1Ty(ctx); 133 auto i32 = LLVMType::getInt32Ty(ctx); 134 auto i64 = LLVMType::getInt64Ty(ctx); 135 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 136 137 if (!module.lookupSymbol(kCoroId)) 138 builder.create<LLVMFuncOp>( 139 loc, kCoroId, 140 LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false)); 141 142 if (!module.lookupSymbol(kCoroSizeI64)) 143 builder.create<LLVMFuncOp>(loc, kCoroSizeI64, 144 LLVMType::getFunctionTy(i64, false)); 145 146 if (!module.lookupSymbol(kCoroBegin)) 147 builder.create<LLVMFuncOp>( 148 loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); 149 150 if (!module.lookupSymbol(kCoroSave)) 151 builder.create<LLVMFuncOp>(loc, kCoroSave, 152 LLVMType::getFunctionTy(token, i8Ptr, false)); 153 154 if (!module.lookupSymbol(kCoroSuspend)) 155 builder.create<LLVMFuncOp>(loc, kCoroSuspend, 156 LLVMType::getFunctionTy(i8, {token, i1}, false)); 157 158 if (!module.lookupSymbol(kCoroEnd)) 159 builder.create<LLVMFuncOp>(loc, kCoroEnd, 160 LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false)); 161 162 if (!module.lookupSymbol(kCoroFree)) 163 builder.create<LLVMFuncOp>( 164 loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false)); 165 166 if (!module.lookupSymbol(kCoroResume)) 167 builder.create<LLVMFuncOp>(loc, kCoroResume, 168 LLVMType::getFunctionTy(voidTy, i8Ptr, false)); 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Add malloc/free declarations to the module. 173 //===----------------------------------------------------------------------===// 174 175 static constexpr const char *kMalloc = "malloc"; 176 static constexpr const char *kFree = "free"; 177 178 /// Adds malloc/free declarations to the module. 179 static void addCRuntimeDeclarations(ModuleOp module) { 180 using namespace mlir::LLVM; 181 182 MLIRContext *ctx = module.getContext(); 183 Location loc = module.getLoc(); 184 185 OpBuilder builder(module.getBody()->getTerminator()); 186 187 auto voidTy = LLVMType::getVoidTy(ctx); 188 auto i64 = LLVMType::getInt64Ty(ctx); 189 auto i8Ptr = LLVMType::getInt8PtrTy(ctx); 190 191 if (!module.lookupSymbol(kMalloc)) 192 builder.create<LLVM::LLVMFuncOp>( 193 loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false)); 194 195 if (!module.lookupSymbol(kFree)) 196 builder.create<LLVM::LLVMFuncOp>( 197 loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false)); 198 } 199 200 //===----------------------------------------------------------------------===// 201 // Coroutine resume function wrapper. 202 //===----------------------------------------------------------------------===// 203 204 static constexpr const char *kResume = "__resume"; 205 206 // A function that takes a coroutine handle and calls a `llvm.coro.resume` 207 // intrinsics. We need this function to be able to pass it to the async 208 // runtime execute API. 209 static void addResumeFunction(ModuleOp module) { 210 MLIRContext *ctx = module.getContext(); 211 212 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 213 Location loc = module.getLoc(); 214 215 if (module.lookupSymbol(kResume)) 216 return; 217 218 auto voidTy = LLVM::LLVMType::getVoidTy(ctx); 219 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 220 221 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 222 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 223 SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private); 224 225 auto *block = resumeOp.addEntryBlock(); 226 OpBuilder blockBuilder = OpBuilder::atBlockEnd(block); 227 228 blockBuilder.create<LLVM::CallOp>(loc, Type(), 229 blockBuilder.getSymbolRefAttr(kCoroResume), 230 resumeOp.getArgument(0)); 231 232 blockBuilder.create<LLVM::ReturnOp>(loc, ValueRange()); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // async.execute op outlining to the coroutine functions. 237 //===----------------------------------------------------------------------===// 238 239 // Function targeted for coroutine transformation has two additional blocks at 240 // the end: coroutine cleanup and coroutine suspension. 241 // 242 // async.await op lowering additionaly creates a resume block for each 243 // operation to enable non-blocking waiting via coroutine suspension. 244 namespace { 245 struct CoroMachinery { 246 Value asyncToken; 247 Value coroHandle; 248 Block *cleanup; 249 Block *suspend; 250 }; 251 } // namespace 252 253 // Builds an coroutine template compatible with LLVM coroutines lowering. 254 // 255 // - `entry` block sets up the coroutine. 256 // - `cleanup` block cleans up the coroutine state. 257 // - `suspend block after the @llvm.coro.end() defines what value will be 258 // returned to the initial caller of a coroutine. Everything before the 259 // @llvm.coro.end() will be executed at every suspension point. 260 // 261 // Coroutine structure (only the important bits): 262 // 263 // func @async_execute_fn(<function-arguments>) -> !async.token { 264 // ^entryBlock(<function-arguments>): 265 // %token = <async token> : !async.token // create async runtime token 266 // %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle 267 // br ^cleanup 268 // 269 // ^cleanup: 270 // llvm.call @llvm.coro.free(...) // delete coroutine state 271 // br ^suspend 272 // 273 // ^suspend: 274 // llvm.call @llvm.coro.end(...) // marks the end of a coroutine 275 // return %token : !async.token 276 // } 277 // 278 // The actual code for the async.execute operation body region will be inserted 279 // before the entry block terminator. 280 // 281 // 282 static CoroMachinery setupCoroMachinery(FuncOp func) { 283 assert(func.getBody().empty() && "Function must have empty body"); 284 285 MLIRContext *ctx = func.getContext(); 286 287 auto token = LLVM::LLVMTokenType::get(ctx); 288 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 289 auto i32 = LLVM::LLVMType::getInt32Ty(ctx); 290 auto i64 = LLVM::LLVMType::getInt64Ty(ctx); 291 auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); 292 293 Block *entryBlock = func.addEntryBlock(); 294 Location loc = func.getBody().getLoc(); 295 296 OpBuilder builder = OpBuilder::atBlockBegin(entryBlock); 297 298 // ------------------------------------------------------------------------ // 299 // Allocate async tokens/values that we will return from a ramp function. 300 // ------------------------------------------------------------------------ // 301 auto createToken = 302 builder.create<CallOp>(loc, kCreateToken, TokenType::get(ctx)); 303 304 // ------------------------------------------------------------------------ // 305 // Initialize coroutine: allocate frame, get coroutine handle. 306 // ------------------------------------------------------------------------ // 307 308 // Constants for initializing coroutine frame. 309 auto constZero = 310 builder.create<LLVM::ConstantOp>(loc, i32, builder.getI32IntegerAttr(0)); 311 auto constFalse = 312 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 313 auto nullPtr = builder.create<LLVM::NullOp>(loc, i8Ptr); 314 315 // Get coroutine id: @llvm.coro.id 316 auto coroId = builder.create<LLVM::CallOp>( 317 loc, token, builder.getSymbolRefAttr(kCoroId), 318 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 319 320 // Get coroutine frame size: @llvm.coro.size.i64 321 auto coroSize = builder.create<LLVM::CallOp>( 322 loc, i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 323 324 // Allocate memory for coroutine frame. 325 auto coroAlloc = builder.create<LLVM::CallOp>( 326 loc, i8Ptr, builder.getSymbolRefAttr(kMalloc), 327 ValueRange(coroSize.getResult(0))); 328 329 // Begin a coroutine: @llvm.coro.begin 330 auto coroHdl = builder.create<LLVM::CallOp>( 331 loc, i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 332 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 333 334 Block *cleanupBlock = func.addBlock(); 335 Block *suspendBlock = func.addBlock(); 336 337 // ------------------------------------------------------------------------ // 338 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 339 // ------------------------------------------------------------------------ // 340 builder.setInsertionPointToStart(cleanupBlock); 341 342 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 343 auto coroMem = builder.create<LLVM::CallOp>( 344 loc, i8Ptr, builder.getSymbolRefAttr(kCoroFree), 345 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 346 347 // Free the memory. 348 builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree), 349 ValueRange(coroMem.getResult(0))); 350 // Branch into the suspend block. 351 builder.create<BranchOp>(loc, suspendBlock); 352 353 // ------------------------------------------------------------------------ // 354 // Coroutine suspend block: mark the end of a coroutine and return allocated 355 // async token. 356 // ------------------------------------------------------------------------ // 357 builder.setInsertionPointToStart(suspendBlock); 358 359 // Mark the end of a coroutine: @llvm.coro.end. 360 builder.create<LLVM::CallOp>(loc, i1, builder.getSymbolRefAttr(kCoroEnd), 361 ValueRange({coroHdl.getResult(0), constFalse})); 362 363 // Return created `async.token` from the suspend block. This will be the 364 // return value of a coroutine ramp function. 365 builder.create<ReturnOp>(loc, createToken.getResult(0)); 366 367 // Branch from the entry block to the cleanup block to create a valid CFG. 368 builder.setInsertionPointToEnd(entryBlock); 369 370 builder.create<BranchOp>(loc, cleanupBlock); 371 372 // `async.await` op lowering will create resume blocks for async 373 // continuations, and will conditionally branch to cleanup or suspend blocks. 374 375 return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock, 376 suspendBlock}; 377 } 378 379 // Adds a suspension point before the `op`, and moves `op` and all operations 380 // after it into the resume block. Returns a pointer to the resume block. 381 // 382 // `coroState` must be a value returned from the call to @llvm.coro.save(...) 383 // intrinsic (saved coroutine state). 384 // 385 // Before: 386 // 387 // ^bb0: 388 // "opBefore"(...) 389 // "op"(...) 390 // ^cleanup: ... 391 // ^suspend: ... 392 // 393 // After: 394 // 395 // ^bb0: 396 // "opBefore"(...) 397 // %suspend = llmv.call @llvm.coro.suspend(...) 398 // switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 399 // ^resume: 400 // "op"(...) 401 // ^cleanup: ... 402 // ^suspend: ... 403 // 404 static Block *addSuspensionPoint(CoroMachinery coro, Value coroState, 405 Operation *op) { 406 MLIRContext *ctx = op->getContext(); 407 auto i1 = LLVM::LLVMType::getInt1Ty(ctx); 408 auto i8 = LLVM::LLVMType::getInt8Ty(ctx); 409 410 Location loc = op->getLoc(); 411 Block *splitBlock = op->getBlock(); 412 413 // Split the block before `op`, newly added block is the resume block. 414 Block *resume = splitBlock->splitBlock(op); 415 416 // Add a coroutine suspension in place of original `op` in the split block. 417 OpBuilder builder = OpBuilder::atBlockEnd(splitBlock); 418 419 auto constFalse = 420 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 421 422 // Suspend a coroutine: @llvm.coro.suspend 423 auto coroSuspend = builder.create<LLVM::CallOp>( 424 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 425 ValueRange({coroState, constFalse})); 426 427 // After a suspension point decide if we should branch into resume, cleanup 428 // or suspend block of the coroutine (see @llvm.coro.suspend return code 429 // documentation). 430 auto constZero = 431 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 432 auto constNegOne = 433 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 434 435 Block *resumeOrCleanup = builder.createBlock(resume); 436 437 // Suspend the coroutine ...? 438 builder.setInsertionPointToEnd(splitBlock); 439 auto isNegOne = builder.create<LLVM::ICmpOp>( 440 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 441 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 442 /*falseDest=*/resumeOrCleanup); 443 444 // ... or resume or cleanup the coroutine? 445 builder.setInsertionPointToStart(resumeOrCleanup); 446 auto isZero = builder.create<LLVM::ICmpOp>( 447 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 448 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 449 /*falseDest=*/coro.cleanup); 450 451 return resume; 452 } 453 454 // Outline the body region attached to the `async.execute` op into a standalone 455 // function. 456 static std::pair<FuncOp, CoroMachinery> 457 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 458 ModuleOp module = execute.getParentOfType<ModuleOp>(); 459 460 MLIRContext *ctx = module.getContext(); 461 Location loc = execute.getLoc(); 462 463 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 464 465 // Get values captured by the async region 466 llvm::SetVector<mlir::Value> usedAbove; 467 getUsedValuesDefinedAbove(execute.body(), usedAbove); 468 469 // Collect types of the captured values. 470 auto usedAboveTypes = 471 llvm::map_range(usedAbove, [](Value value) { return value.getType(); }); 472 SmallVector<Type, 4> inputTypes(usedAboveTypes.begin(), usedAboveTypes.end()); 473 auto outputTypes = execute.getResultTypes(); 474 475 auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); 476 auto funcAttrs = ArrayRef<NamedAttribute>(); 477 478 // TODO: Derive outlined function name from the parent FuncOp (support 479 // multiple nested async.execute operations). 480 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 481 symbolTable.insert(func, moduleBuilder.getInsertionPoint()); 482 483 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 484 485 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 486 // blocks, adding llvm.coro instrinsics and setting up control flow. 487 CoroMachinery coro = setupCoroMachinery(func); 488 489 // Suspend async function at the end of an entry block, and resume it using 490 // Async execute API (execution will be resumed in a thread managed by the 491 // async runtime). 492 Block *entryBlock = &func.getBlocks().front(); 493 OpBuilder builder = OpBuilder::atBlockTerminator(entryBlock); 494 495 // A pointer to coroutine resume intrinsic wrapper. 496 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 497 auto resumePtr = builder.create<LLVM::AddressOfOp>( 498 loc, resumeFnTy.getPointerTo(), kResume); 499 500 // Save the coroutine state: @llvm.coro.save 501 auto coroSave = builder.create<LLVM::CallOp>( 502 loc, LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 503 ValueRange({coro.coroHandle})); 504 505 // Call async runtime API to execute a coroutine in the managed thread. 506 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 507 builder.create<CallOp>(loc, Type(), kExecute, executeArgs); 508 509 // Split the entry block before the terminator. 510 Block *resume = addSuspensionPoint(coro, coroSave.getResult(0), 511 entryBlock->getTerminator()); 512 513 // Map from values defined above the execute op to the function arguments. 514 BlockAndValueMapping valueMapping; 515 valueMapping.map(usedAbove, func.getArguments()); 516 517 // Clone all operations from the execute operation body into the outlined 518 // function body, and replace all `async.yield` operations with a call 519 // to async runtime to emplace the result token. 520 builder.setInsertionPointToStart(resume); 521 for (Operation &op : execute.body().getOps()) { 522 if (isa<async::YieldOp>(op)) { 523 builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken); 524 continue; 525 } 526 builder.clone(op, valueMapping); 527 } 528 529 // Replace the original `async.execute` with a call to outlined function. 530 OpBuilder callBuilder(execute); 531 SmallVector<Value, 4> usedAboveArgs(usedAbove.begin(), usedAbove.end()); 532 auto callOutlinedFunc = callBuilder.create<CallOp>( 533 loc, func.getName(), execute.getResultTypes(), usedAboveArgs); 534 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 535 execute.erase(); 536 537 return {func, coro}; 538 } 539 540 //===----------------------------------------------------------------------===// 541 // Convert Async dialect types to LLVM types. 542 //===----------------------------------------------------------------------===// 543 544 namespace { 545 class AsyncRuntimeTypeConverter : public TypeConverter { 546 public: 547 AsyncRuntimeTypeConverter() { addConversion(convertType); } 548 549 static Type convertType(Type type) { 550 MLIRContext *ctx = type.getContext(); 551 // Convert async tokens to opaque pointers. 552 if (type.isa<TokenType>()) 553 return LLVM::LLVMType::getInt8PtrTy(ctx); 554 return type; 555 } 556 }; 557 } // namespace 558 559 //===----------------------------------------------------------------------===// 560 // Convert types for all call operations to lowered async types. 561 //===----------------------------------------------------------------------===// 562 563 namespace { 564 class CallOpOpConversion : public ConversionPattern { 565 public: 566 explicit CallOpOpConversion(MLIRContext *ctx) 567 : ConversionPattern(CallOp::getOperationName(), 1, ctx) {} 568 569 LogicalResult 570 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 571 ConversionPatternRewriter &rewriter) const override { 572 AsyncRuntimeTypeConverter converter; 573 574 SmallVector<Type, 5> resultTypes; 575 converter.convertTypes(op->getResultTypes(), resultTypes); 576 577 CallOp call = cast<CallOp>(op); 578 rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(), 579 call.getOperands()); 580 581 return success(); 582 } 583 }; 584 } // namespace 585 586 //===----------------------------------------------------------------------===// 587 // async.await op lowering to mlirAsyncRuntimeAwaitToken function call. 588 //===----------------------------------------------------------------------===// 589 590 namespace { 591 class AwaitOpLowering : public ConversionPattern { 592 public: 593 explicit AwaitOpLowering( 594 MLIRContext *ctx, 595 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 596 : ConversionPattern(AwaitOp::getOperationName(), 1, ctx), 597 outlinedFunctions(outlinedFunctions) {} 598 599 LogicalResult 600 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 601 ConversionPatternRewriter &rewriter) const override { 602 // We can only await on the token operand. Async valus are not supported. 603 auto await = cast<AwaitOp>(op); 604 if (!await.operand().getType().isa<TokenType>()) 605 return failure(); 606 607 // Check if `async.await` is inside the outlined coroutine function. 608 auto func = await.getParentOfType<FuncOp>(); 609 auto outlined = outlinedFunctions.find(func); 610 const bool isInCoroutine = outlined != outlinedFunctions.end(); 611 612 Location loc = op->getLoc(); 613 614 // Inside regular function we convert await operation to the blocking 615 // async API await function call. 616 if (!isInCoroutine) 617 rewriter.create<CallOp>(loc, Type(), kAwaitToken, 618 ValueRange(op->getOperand(0))); 619 620 // Inside the coroutine we convert await operation into coroutine suspension 621 // point, and resume execution asynchronously. 622 if (isInCoroutine) { 623 const CoroMachinery &coro = outlined->getSecond(); 624 625 OpBuilder builder(op); 626 MLIRContext *ctx = op->getContext(); 627 628 // A pointer to coroutine resume intrinsic wrapper. 629 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 630 auto resumePtr = builder.create<LLVM::AddressOfOp>( 631 loc, resumeFnTy.getPointerTo(), kResume); 632 633 // Save the coroutine state: @llvm.coro.save 634 auto coroSave = builder.create<LLVM::CallOp>( 635 loc, LLVM::LLVMTokenType::get(ctx), 636 builder.getSymbolRefAttr(kCoroSave), ValueRange(coro.coroHandle)); 637 638 // Call async runtime API to resume a coroutine in the managed thread when 639 // the async await argument becomes ready. 640 SmallVector<Value, 3> awaitAndExecuteArgs = { 641 await.getOperand(), coro.coroHandle, resumePtr.res()}; 642 builder.create<CallOp>(loc, Type(), kAwaitAndExecute, 643 awaitAndExecuteArgs); 644 645 // Split the entry block before the await operation. 646 addSuspensionPoint(coro, coroSave.getResult(0), op); 647 } 648 649 // Original operation was replaced by function call or suspension point. 650 rewriter.eraseOp(op); 651 652 return success(); 653 } 654 655 private: 656 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 657 }; 658 } // namespace 659 660 //===----------------------------------------------------------------------===// 661 662 namespace { 663 struct ConvertAsyncToLLVMPass 664 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 665 void runOnOperation() override; 666 }; 667 668 void ConvertAsyncToLLVMPass::runOnOperation() { 669 ModuleOp module = getOperation(); 670 SymbolTable symbolTable(module); 671 672 // Outline all `async.execute` body regions into async functions (coroutines). 673 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 674 675 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 676 // We currently do not support execute operations that take async 677 // token dependencies, async value arguments or produce async results. 678 if (!execute.dependencies().empty() || !execute.operands().empty() || 679 !execute.results().empty()) { 680 execute.emitOpError( 681 "Can't outline async.execute op with async dependencies, arguments " 682 "or returned async results"); 683 return WalkResult::interrupt(); 684 } 685 686 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 687 688 return WalkResult::advance(); 689 }); 690 691 // Failed to outline all async execute operations. 692 if (outlineResult.wasInterrupted()) { 693 signalPassFailure(); 694 return; 695 } 696 697 LLVM_DEBUG({ 698 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 699 << " async functions\n"; 700 }); 701 702 // Add declarations for all functions required by the coroutines lowering. 703 addResumeFunction(module); 704 addAsyncRuntimeApiDeclarations(module); 705 addCoroutineIntrinsicsDeclarations(module); 706 addCRuntimeDeclarations(module); 707 708 MLIRContext *ctx = &getContext(); 709 710 // Convert async dialect types and operations to LLVM dialect. 711 AsyncRuntimeTypeConverter converter; 712 OwningRewritePatternList patterns; 713 714 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 715 patterns.insert<CallOpOpConversion>(ctx); 716 patterns.insert<AwaitOpLowering>(ctx, outlinedFunctions); 717 718 ConversionTarget target(*ctx); 719 target.addLegalDialect<LLVM::LLVMDialect>(); 720 target.addIllegalDialect<AsyncDialect>(); 721 target.addDynamicallyLegalOp<FuncOp>( 722 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 723 target.addDynamicallyLegalOp<CallOp>( 724 [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); 725 726 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 727 signalPassFailure(); 728 } 729 } // namespace 730 731 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 732 return std::make_unique<ConvertAsyncToLLVMPass>(); 733 } 734