1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/BlockAndValueMapping.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "mlir/Transforms/RegionUtils.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/Support/FormatVariadic.h" 25 26 #define DEBUG_TYPE "convert-async-to-llvm" 27 28 using namespace mlir; 29 using namespace mlir::async; 30 31 // Prefix for functions outlined from `async.execute` op regions. 32 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 33 34 //===----------------------------------------------------------------------===// 35 // Async Runtime C API declaration. 36 //===----------------------------------------------------------------------===// 37 38 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 39 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 40 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 41 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 42 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 43 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 44 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 45 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 46 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 47 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 48 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 49 static constexpr const char *kGetValueStorage = 50 "mlirAsyncRuntimeGetValueStorage"; 51 static constexpr const char *kAddTokenToGroup = 52 "mlirAsyncRuntimeAddTokenToGroup"; 53 static constexpr const char *kAwaitTokenAndExecute = 54 "mlirAsyncRuntimeAwaitTokenAndExecute"; 55 static constexpr const char *kAwaitValueAndExecute = 56 "mlirAsyncRuntimeAwaitValueAndExecute"; 57 static constexpr const char *kAwaitAllAndExecute = 58 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 59 60 namespace { 61 /// Async Runtime API function types. 62 /// 63 /// Because we can't create API function signature for type parametrized 64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 65 /// lowering all async data types become opaque pointers at runtime. 66 struct AsyncAPI { 67 // All async types are lowered to opaque i8* LLVM pointers at runtime. 68 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 69 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 70 } 71 72 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 73 return LLVM::LLVMTokenType::get(ctx); 74 } 75 76 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 77 auto ref = opaquePointerType(ctx); 78 auto count = IntegerType::get(ctx, 32); 79 return FunctionType::get(ctx, {ref, count}, {}); 80 } 81 82 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 83 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 84 } 85 86 static FunctionType createValueFunctionType(MLIRContext *ctx) { 87 auto i32 = IntegerType::get(ctx, 32); 88 auto value = opaquePointerType(ctx); 89 return FunctionType::get(ctx, {i32}, {value}); 90 } 91 92 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 93 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 94 } 95 96 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 97 auto value = opaquePointerType(ctx); 98 auto storage = opaquePointerType(ctx); 99 return FunctionType::get(ctx, {value}, {storage}); 100 } 101 102 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 103 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 104 } 105 106 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 107 auto value = opaquePointerType(ctx); 108 return FunctionType::get(ctx, {value}, {}); 109 } 110 111 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 112 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 113 } 114 115 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 116 auto value = opaquePointerType(ctx); 117 return FunctionType::get(ctx, {value}, {}); 118 } 119 120 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 121 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 122 } 123 124 static FunctionType executeFunctionType(MLIRContext *ctx) { 125 auto hdl = opaquePointerType(ctx); 126 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 127 return FunctionType::get(ctx, {hdl, resume}, {}); 128 } 129 130 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 131 auto i64 = IntegerType::get(ctx, 64); 132 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 133 {i64}); 134 } 135 136 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 137 auto hdl = opaquePointerType(ctx); 138 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 139 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 140 } 141 142 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 143 auto value = opaquePointerType(ctx); 144 auto hdl = opaquePointerType(ctx); 145 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 146 return FunctionType::get(ctx, {value, hdl, resume}, {}); 147 } 148 149 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 150 auto hdl = opaquePointerType(ctx); 151 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 152 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 153 } 154 155 // Auxiliary coroutine resume intrinsic wrapper. 156 static Type resumeFunctionType(MLIRContext *ctx) { 157 auto voidTy = LLVM::LLVMVoidType::get(ctx); 158 auto i8Ptr = opaquePointerType(ctx); 159 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 160 } 161 }; 162 } // namespace 163 164 /// Adds Async Runtime C API declarations to the module. 165 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 166 auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), 167 module.getBody()); 168 169 auto addFuncDecl = [&](StringRef name, FunctionType type) { 170 if (module.lookupSymbol(name)) 171 return; 172 builder.create<FuncOp>(name, type).setPrivate(); 173 }; 174 175 MLIRContext *ctx = module.getContext(); 176 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 177 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 178 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 179 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 180 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 181 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 182 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 183 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 184 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 185 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 186 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 187 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 188 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 189 addFuncDecl(kAwaitTokenAndExecute, 190 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 191 addFuncDecl(kAwaitValueAndExecute, 192 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 193 addFuncDecl(kAwaitAllAndExecute, 194 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 195 } 196 197 //===----------------------------------------------------------------------===// 198 // LLVM coroutines intrinsics declarations. 199 //===----------------------------------------------------------------------===// 200 201 static constexpr const char *kCoroId = "llvm.coro.id"; 202 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 203 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 204 static constexpr const char *kCoroSave = "llvm.coro.save"; 205 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 206 static constexpr const char *kCoroEnd = "llvm.coro.end"; 207 static constexpr const char *kCoroFree = "llvm.coro.free"; 208 static constexpr const char *kCoroResume = "llvm.coro.resume"; 209 210 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 211 StringRef name, Type ret, ArrayRef<Type> params) { 212 if (module.lookupSymbol(name)) 213 return; 214 Type type = LLVM::LLVMFunctionType::get(ret, params); 215 builder.create<LLVM::LLVMFuncOp>(name, type); 216 } 217 218 /// Adds coroutine intrinsics declarations to the module. 219 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 220 using namespace mlir::LLVM; 221 222 MLIRContext *ctx = module.getContext(); 223 ImplicitLocOpBuilder builder(module.getLoc(), 224 module.getBody()->getTerminator()); 225 226 auto token = LLVMTokenType::get(ctx); 227 auto voidTy = LLVMVoidType::get(ctx); 228 229 auto i8 = IntegerType::get(ctx, 8); 230 auto i1 = IntegerType::get(ctx, 1); 231 auto i32 = IntegerType::get(ctx, 32); 232 auto i64 = IntegerType::get(ctx, 64); 233 auto i8Ptr = LLVMPointerType::get(i8); 234 235 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); 236 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); 237 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); 238 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); 239 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); 240 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); 241 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); 242 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); 243 } 244 245 //===----------------------------------------------------------------------===// 246 // Add malloc/free declarations to the module. 247 //===----------------------------------------------------------------------===// 248 249 static constexpr const char *kMalloc = "malloc"; 250 static constexpr const char *kFree = "free"; 251 252 /// Adds malloc/free declarations to the module. 253 static void addCRuntimeDeclarations(ModuleOp module) { 254 using namespace mlir::LLVM; 255 256 MLIRContext *ctx = module.getContext(); 257 ImplicitLocOpBuilder builder(module.getLoc(), 258 module.getBody()->getTerminator()); 259 260 auto voidTy = LLVMVoidType::get(ctx); 261 auto i64 = IntegerType::get(ctx, 64); 262 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 263 264 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 265 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 266 } 267 268 //===----------------------------------------------------------------------===// 269 // Coroutine resume function wrapper. 270 //===----------------------------------------------------------------------===// 271 272 static constexpr const char *kResume = "__resume"; 273 274 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 275 /// intrinsics. We need this function to be able to pass it to the async 276 /// runtime execute API. 277 static void addResumeFunction(ModuleOp module) { 278 MLIRContext *ctx = module.getContext(); 279 280 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 281 Location loc = module.getLoc(); 282 283 if (module.lookupSymbol(kResume)) 284 return; 285 286 auto voidTy = LLVM::LLVMVoidType::get(ctx); 287 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 288 289 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 290 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 291 resumeOp.setPrivate(); 292 293 auto *block = resumeOp.addEntryBlock(); 294 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 295 296 blockBuilder.create<LLVM::CallOp>(TypeRange(), 297 blockBuilder.getSymbolRefAttr(kCoroResume), 298 resumeOp.getArgument(0)); 299 300 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 301 } 302 303 //===----------------------------------------------------------------------===// 304 // async.execute op outlining to the coroutine functions. 305 //===----------------------------------------------------------------------===// 306 307 /// Function targeted for coroutine transformation has two additional blocks at 308 /// the end: coroutine cleanup and coroutine suspension. 309 /// 310 /// async.await op lowering additionaly creates a resume block for each 311 /// operation to enable non-blocking waiting via coroutine suspension. 312 namespace { 313 struct CoroMachinery { 314 // Async execute region returns a completion token, and an async value for 315 // each yielded value. 316 // 317 // %token, %result = async.execute -> !async.value<T> { 318 // %0 = constant ... : T 319 // async.yield %0 : T 320 // } 321 Value asyncToken; // token representing completion of the async region 322 llvm::SmallVector<Value, 4> returnValues; // returned async values 323 324 Value coroHandle; // coroutine handle (!async.coro.handle value) 325 Block *cleanup; // coroutine cleanup block 326 Block *suspend; // coroutine suspension block 327 }; 328 } // namespace 329 330 /// Builds an coroutine template compatible with LLVM coroutines switched-resume 331 /// lowering using `async.runtime.*` and `async.coro.*` operations. 332 /// 333 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 334 /// 335 /// - `entry` block sets up the coroutine. 336 /// - `cleanup` block cleans up the coroutine state. 337 /// - `suspend block after the @llvm.coro.end() defines what value will be 338 /// returned to the initial caller of a coroutine. Everything before the 339 /// @llvm.coro.end() will be executed at every suspension point. 340 /// 341 /// Coroutine structure (only the important bits): 342 /// 343 /// func @async_execute_fn(<function-arguments>) 344 /// -> (!async.token, !async.value<T>) 345 /// { 346 /// ^entry(<function-arguments>): 347 /// %token = <async token> : !async.token // create async runtime token 348 /// %value = <async value> : !async.value<T> // create async value 349 /// %id = async.coro.id // create a coroutine id 350 /// %hdl = async.coro.begin %id // create a coroutine handle 351 /// br ^cleanup 352 /// 353 /// ^cleanup: 354 /// async.coro.free %hdl // delete the coroutine state 355 /// br ^suspend 356 /// 357 /// ^suspend: 358 /// async.coro.end %hdl // marks the end of a coroutine 359 /// return %token, %value : !async.token, !async.value<T> 360 /// } 361 /// 362 /// The actual code for the async.execute operation body region will be inserted 363 /// before the entry block terminator. 364 /// 365 /// 366 static CoroMachinery setupCoroMachinery(FuncOp func) { 367 assert(func.getBody().empty() && "Function must have empty body"); 368 369 MLIRContext *ctx = func.getContext(); 370 Block *entryBlock = func.addEntryBlock(); 371 372 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 373 374 // ------------------------------------------------------------------------ // 375 // Allocate async token/values that we will return from a ramp function. 376 // ------------------------------------------------------------------------ // 377 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 378 379 llvm::SmallVector<Value, 4> retValues; 380 for (auto resType : func.getCallableResults().drop_front()) 381 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 382 383 // ------------------------------------------------------------------------ // 384 // Initialize coroutine: get coroutine id and coroutine handle. 385 // ------------------------------------------------------------------------ // 386 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 387 auto coroHdlOp = 388 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 389 390 Block *cleanupBlock = func.addBlock(); 391 Block *suspendBlock = func.addBlock(); 392 393 // ------------------------------------------------------------------------ // 394 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 395 // ------------------------------------------------------------------------ // 396 builder.setInsertionPointToStart(cleanupBlock); 397 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 398 399 // Branch into the suspend block. 400 builder.create<BranchOp>(suspendBlock); 401 402 // ------------------------------------------------------------------------ // 403 // Coroutine suspend block: mark the end of a coroutine and return allocated 404 // async token. 405 // ------------------------------------------------------------------------ // 406 builder.setInsertionPointToStart(suspendBlock); 407 408 // Mark the end of a coroutine: async.coro.end 409 builder.create<CoroEndOp>(coroHdlOp.handle()); 410 411 // Return created `async.token` and `async.values` from the suspend block. 412 // This will be the return value of a coroutine ramp function. 413 SmallVector<Value, 4> ret{retToken}; 414 ret.insert(ret.end(), retValues.begin(), retValues.end()); 415 builder.create<ReturnOp>(ret); 416 417 // Branch from the entry block to the cleanup block to create a valid CFG. 418 builder.setInsertionPointToEnd(entryBlock); 419 builder.create<BranchOp>(cleanupBlock); 420 421 // `async.await` op lowering will create resume blocks for async 422 // continuations, and will conditionally branch to cleanup or suspend blocks. 423 424 CoroMachinery machinery; 425 machinery.asyncToken = retToken; 426 machinery.returnValues = retValues; 427 machinery.coroHandle = coroHdlOp.handle(); 428 machinery.cleanup = cleanupBlock; 429 machinery.suspend = suspendBlock; 430 return machinery; 431 } 432 433 /// Outline the body region attached to the `async.execute` op into a standalone 434 /// function. 435 /// 436 /// Note that this is not reversible transformation. 437 static std::pair<FuncOp, CoroMachinery> 438 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 439 ModuleOp module = execute->getParentOfType<ModuleOp>(); 440 441 MLIRContext *ctx = module.getContext(); 442 Location loc = execute.getLoc(); 443 444 // Collect all outlined function inputs. 445 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 446 execute.dependencies().end()); 447 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 448 getUsedValuesDefinedAbove(execute.body(), functionInputs); 449 450 // Collect types for the outlined function inputs and outputs. 451 auto typesRange = llvm::map_range( 452 functionInputs, [](Value value) { return value.getType(); }); 453 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 454 auto outputTypes = execute.getResultTypes(); 455 456 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 457 auto funcAttrs = ArrayRef<NamedAttribute>(); 458 459 // TODO: Derive outlined function name from the parent FuncOp (support 460 // multiple nested async.execute operations). 461 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 462 symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); 463 464 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 465 466 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 467 // blocks, adding async.coro operations and setting up control flow. 468 CoroMachinery coro = setupCoroMachinery(func); 469 470 // Suspend async function at the end of an entry block, and resume it using 471 // Async resume operation (execution will be resumed in a thread managed by 472 // the async runtime). 473 Block *entryBlock = &func.getBlocks().front(); 474 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 475 476 // Save the coroutine state: async.coro.save 477 auto coroSaveOp = 478 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 479 480 // Pass coroutine to the runtime to be resumed on a runtime managed thread. 481 builder.create<RuntimeResumeOp>(coro.coroHandle); 482 483 // Split the entry block before the terminator (branch to suspend block). 484 auto *terminatorOp = entryBlock->getTerminator(); 485 Block *suspended = terminatorOp->getBlock(); 486 Block *resume = suspended->splitBlock(terminatorOp); 487 488 // Add async.coro.suspend as a suspended block terminator. 489 builder.setInsertionPointToEnd(suspended); 490 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 491 coro.cleanup); 492 493 size_t numDependencies = execute.dependencies().size(); 494 size_t numOperands = execute.operands().size(); 495 496 // Await on all dependencies before starting to execute the body region. 497 builder.setInsertionPointToStart(resume); 498 for (size_t i = 0; i < numDependencies; ++i) 499 builder.create<AwaitOp>(func.getArgument(i)); 500 501 // Await on all async value operands and unwrap the payload. 502 SmallVector<Value, 4> unwrappedOperands(numOperands); 503 for (size_t i = 0; i < numOperands; ++i) { 504 Value operand = func.getArgument(numDependencies + i); 505 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 506 } 507 508 // Map from function inputs defined above the execute op to the function 509 // arguments. 510 BlockAndValueMapping valueMapping; 511 valueMapping.map(functionInputs, func.getArguments()); 512 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 513 514 // Clone all operations from the execute operation body into the outlined 515 // function body. 516 for (Operation &op : execute.body().getOps()) 517 builder.clone(op, valueMapping); 518 519 // Replace the original `async.execute` with a call to outlined function. 520 ImplicitLocOpBuilder callBuilder(loc, execute); 521 auto callOutlinedFunc = callBuilder.create<CallOp>( 522 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 523 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 524 execute.erase(); 525 526 return {func, coro}; 527 } 528 529 //===----------------------------------------------------------------------===// 530 // Convert Async dialect types to LLVM types. 531 //===----------------------------------------------------------------------===// 532 533 namespace { 534 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 535 /// their runtime type (opaque pointers) and does not convert any other types. 536 class AsyncRuntimeTypeConverter : public TypeConverter { 537 public: 538 AsyncRuntimeTypeConverter() { 539 addConversion([](Type type) { return type; }); 540 addConversion(convertAsyncTypes); 541 } 542 543 static Optional<Type> convertAsyncTypes(Type type) { 544 if (type.isa<TokenType, GroupType, ValueType>()) 545 return AsyncAPI::opaquePointerType(type.getContext()); 546 547 if (type.isa<CoroIdType, CoroStateType>()) 548 return AsyncAPI::tokenType(type.getContext()); 549 if (type.isa<CoroHandleType>()) 550 return AsyncAPI::opaquePointerType(type.getContext()); 551 552 return llvm::None; 553 } 554 }; 555 } // namespace 556 557 //===----------------------------------------------------------------------===// 558 // Convert async.coro.id to @llvm.coro.id intrinsic. 559 //===----------------------------------------------------------------------===// 560 561 namespace { 562 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> { 563 public: 564 using OpConversionPattern::OpConversionPattern; 565 566 LogicalResult 567 matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands, 568 ConversionPatternRewriter &rewriter) const override { 569 auto token = AsyncAPI::tokenType(op->getContext()); 570 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 571 auto loc = op->getLoc(); 572 573 // Constants for initializing coroutine frame. 574 auto constZero = rewriter.create<LLVM::ConstantOp>( 575 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); 576 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr); 577 578 // Get coroutine id: @llvm.coro.id. 579 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 580 op, token, rewriter.getSymbolRefAttr(kCoroId), 581 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 582 583 return success(); 584 } 585 }; 586 } // namespace 587 588 //===----------------------------------------------------------------------===// 589 // Convert async.coro.begin to @llvm.coro.begin intrinsic. 590 //===----------------------------------------------------------------------===// 591 592 namespace { 593 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> { 594 public: 595 using OpConversionPattern::OpConversionPattern; 596 597 LogicalResult 598 matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands, 599 ConversionPatternRewriter &rewriter) const override { 600 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 601 auto loc = op->getLoc(); 602 603 // Get coroutine frame size: @llvm.coro.size.i64. 604 auto coroSize = rewriter.create<LLVM::CallOp>( 605 loc, rewriter.getI64Type(), rewriter.getSymbolRefAttr(kCoroSizeI64), 606 ValueRange()); 607 608 // Allocate memory for the coroutine frame. 609 auto coroAlloc = rewriter.create<LLVM::CallOp>( 610 loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc), 611 ValueRange(coroSize.getResult(0))); 612 613 // Begin a coroutine: @llvm.coro.begin. 614 auto coroId = CoroBeginOpAdaptor(operands).id(); 615 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 616 op, i8Ptr, rewriter.getSymbolRefAttr(kCoroBegin), 617 ValueRange({coroId, coroAlloc.getResult(0)})); 618 619 return success(); 620 } 621 }; 622 } // namespace 623 624 //===----------------------------------------------------------------------===// 625 // Convert async.coro.free to @llvm.coro.free intrinsic. 626 //===----------------------------------------------------------------------===// 627 628 namespace { 629 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> { 630 public: 631 using OpConversionPattern::OpConversionPattern; 632 633 LogicalResult 634 matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands, 635 ConversionPatternRewriter &rewriter) const override { 636 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); 637 auto loc = op->getLoc(); 638 639 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 640 auto coroMem = rewriter.create<LLVM::CallOp>( 641 loc, i8Ptr, rewriter.getSymbolRefAttr(kCoroFree), operands); 642 643 // Free the memory. 644 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(), 645 rewriter.getSymbolRefAttr(kFree), 646 ValueRange(coroMem.getResult(0))); 647 648 return success(); 649 } 650 }; 651 } // namespace 652 653 //===----------------------------------------------------------------------===// 654 // Convert async.coro.end to @llvm.coro.end intrinsic. 655 //===----------------------------------------------------------------------===// 656 657 namespace { 658 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 659 public: 660 using OpConversionPattern::OpConversionPattern; 661 662 LogicalResult 663 matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands, 664 ConversionPatternRewriter &rewriter) const override { 665 // We are not in the block that is part of the unwind sequence. 666 auto constFalse = rewriter.create<LLVM::ConstantOp>( 667 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 668 669 // Mark the end of a coroutine: @llvm.coro.end. 670 auto coroHdl = CoroEndOpAdaptor(operands).handle(); 671 rewriter.create<LLVM::CallOp>(op->getLoc(), rewriter.getI1Type(), 672 rewriter.getSymbolRefAttr(kCoroEnd), 673 ValueRange({coroHdl, constFalse})); 674 rewriter.eraseOp(op); 675 676 return success(); 677 } 678 }; 679 } // namespace 680 681 //===----------------------------------------------------------------------===// 682 // Convert async.coro.save to @llvm.coro.save intrinsic. 683 //===----------------------------------------------------------------------===// 684 685 namespace { 686 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 687 public: 688 using OpConversionPattern::OpConversionPattern; 689 690 LogicalResult 691 matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands, 692 ConversionPatternRewriter &rewriter) const override { 693 // Save the coroutine state: @llvm.coro.save 694 rewriter.replaceOpWithNewOp<LLVM::CallOp>( 695 op, AsyncAPI::tokenType(op->getContext()), 696 rewriter.getSymbolRefAttr(kCoroSave), operands); 697 698 return success(); 699 } 700 }; 701 } // namespace 702 703 //===----------------------------------------------------------------------===// 704 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 705 //===----------------------------------------------------------------------===// 706 707 namespace { 708 709 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 710 /// branch to the appropriate block based on the return code. 711 /// 712 /// Before: 713 /// 714 /// ^suspended: 715 /// "opBefore"(...) 716 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 717 /// ^resume: 718 /// "op"(...) 719 /// ^cleanup: ... 720 /// ^suspend: ... 721 /// 722 /// After: 723 /// 724 /// ^suspended: 725 /// "opBefore"(...) 726 /// %suspend = llmv.call @llvm.coro.suspend(...) 727 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 728 /// ^resume: 729 /// "op"(...) 730 /// ^cleanup: ... 731 /// ^suspend: ... 732 /// 733 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 734 public: 735 using OpConversionPattern::OpConversionPattern; 736 737 LogicalResult 738 matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands, 739 ConversionPatternRewriter &rewriter) const override { 740 auto i8 = rewriter.getIntegerType(8); 741 auto i32 = rewriter.getI32Type(); 742 auto loc = op->getLoc(); 743 744 // This is not a final suspension point. 745 auto constFalse = rewriter.create<LLVM::ConstantOp>( 746 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 747 748 // Suspend a coroutine: @llvm.coro.suspend 749 auto coroState = CoroSuspendOpAdaptor(operands).state(); 750 auto coroSuspend = rewriter.create<LLVM::CallOp>( 751 loc, i8, rewriter.getSymbolRefAttr(kCoroSuspend), 752 ValueRange({coroState, constFalse})); 753 754 // Cast return code to i32. 755 756 // After a suspension point decide if we should branch into resume, cleanup 757 // or suspend block of the coroutine (see @llvm.coro.suspend return code 758 // documentation). 759 llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 760 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(), 761 op.cleanupDest()}; 762 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 763 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult(0)), 764 /*defaultDestination=*/op.suspendDest(), 765 /*defaultOperands=*/ValueRange(), 766 /*caseValues=*/caseValues, 767 /*caseDestinations=*/caseDest, 768 /*caseOperands=*/ArrayRef<ValueRange>(), 769 /*branchWeights=*/ArrayRef<int32_t>()); 770 771 return success(); 772 } 773 }; 774 } // namespace 775 776 //===----------------------------------------------------------------------===// 777 // Convert async.runtime.create to the corresponding runtime API call. 778 // 779 // To allocate storage for the async values we use getelementptr trick: 780 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 781 //===----------------------------------------------------------------------===// 782 783 namespace { 784 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> { 785 public: 786 using OpConversionPattern::OpConversionPattern; 787 788 LogicalResult 789 matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands, 790 ConversionPatternRewriter &rewriter) const override { 791 TypeConverter *converter = getTypeConverter(); 792 Type resultType = op->getResultTypes()[0]; 793 794 // Tokens and Groups lowered to function calls without arguments. 795 if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) { 796 rewriter.replaceOpWithNewOp<CallOp>( 797 op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup, 798 converter->convertType(resultType)); 799 return success(); 800 } 801 802 // To create a value we need to compute the storage requirement. 803 if (auto value = resultType.dyn_cast<ValueType>()) { 804 // Returns the size requirements for the async value storage. 805 auto sizeOf = [&](ValueType valueType) -> Value { 806 auto loc = op->getLoc(); 807 auto i32 = rewriter.getI32Type(); 808 809 auto storedType = converter->convertType(valueType.getValueType()); 810 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 811 812 // %Size = getelementptr %T* null, int 1 813 // %SizeI = ptrtoint %T* %Size to i32 814 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType); 815 auto one = rewriter.create<LLVM::ConstantOp>( 816 loc, i32, rewriter.getI32IntegerAttr(1)); 817 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 818 one.getResult()); 819 return rewriter.create<LLVM::PtrToIntOp>(loc, i32, gep); 820 }; 821 822 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateValue, resultType, 823 sizeOf(value)); 824 825 return success(); 826 } 827 828 return rewriter.notifyMatchFailure(op, "unsupported async type"); 829 } 830 }; 831 } // namespace 832 833 //===----------------------------------------------------------------------===// 834 // Convert async.runtime.set_available to the corresponding runtime API call. 835 //===----------------------------------------------------------------------===// 836 837 namespace { 838 class RuntimeSetAvailableOpLowering 839 : public OpConversionPattern<RuntimeSetAvailableOp> { 840 public: 841 using OpConversionPattern::OpConversionPattern; 842 843 LogicalResult 844 matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands, 845 ConversionPatternRewriter &rewriter) const override { 846 Type operandType = op.operand().getType(); 847 848 if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) { 849 rewriter.create<CallOp>(op->getLoc(), 850 operandType.isa<TokenType>() ? kEmplaceToken 851 : kEmplaceValue, 852 TypeRange(), operands); 853 rewriter.eraseOp(op); 854 return success(); 855 } 856 857 return rewriter.notifyMatchFailure(op, "unsupported async type"); 858 } 859 }; 860 } // namespace 861 862 //===----------------------------------------------------------------------===// 863 // Convert async.runtime.await to the corresponding runtime API call. 864 //===----------------------------------------------------------------------===// 865 866 namespace { 867 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 868 public: 869 using OpConversionPattern::OpConversionPattern; 870 871 LogicalResult 872 matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands, 873 ConversionPatternRewriter &rewriter) const override { 874 Type operandType = op.operand().getType(); 875 876 StringRef apiFuncName; 877 if (operandType.isa<TokenType>()) 878 apiFuncName = kAwaitToken; 879 else if (operandType.isa<ValueType>()) 880 apiFuncName = kAwaitValue; 881 else if (operandType.isa<GroupType>()) 882 apiFuncName = kAwaitGroup; 883 else 884 return rewriter.notifyMatchFailure(op, "unsupported async type"); 885 886 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands); 887 rewriter.eraseOp(op); 888 889 return success(); 890 } 891 }; 892 } // namespace 893 894 //===----------------------------------------------------------------------===// 895 // Convert async.runtime.await_and_resume to the corresponding runtime API call. 896 //===----------------------------------------------------------------------===// 897 898 namespace { 899 class RuntimeAwaitAndResumeOpLowering 900 : public OpConversionPattern<RuntimeAwaitAndResumeOp> { 901 public: 902 using OpConversionPattern::OpConversionPattern; 903 904 LogicalResult 905 matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands, 906 ConversionPatternRewriter &rewriter) const override { 907 Type operandType = op.operand().getType(); 908 909 StringRef apiFuncName; 910 if (operandType.isa<TokenType>()) 911 apiFuncName = kAwaitTokenAndExecute; 912 else if (operandType.isa<ValueType>()) 913 apiFuncName = kAwaitValueAndExecute; 914 else if (operandType.isa<GroupType>()) 915 apiFuncName = kAwaitAllAndExecute; 916 else 917 return rewriter.notifyMatchFailure(op, "unsupported async type"); 918 919 Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); 920 Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); 921 922 // A pointer to coroutine resume intrinsic wrapper. 923 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 924 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 925 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 926 927 rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), 928 ValueRange({operand, handle, resumePtr.res()})); 929 rewriter.eraseOp(op); 930 931 return success(); 932 } 933 }; 934 } // namespace 935 936 //===----------------------------------------------------------------------===// 937 // Convert async.runtime.resume to the corresponding runtime API call. 938 //===----------------------------------------------------------------------===// 939 940 namespace { 941 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> { 942 public: 943 using OpConversionPattern::OpConversionPattern; 944 945 LogicalResult 946 matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands, 947 ConversionPatternRewriter &rewriter) const override { 948 // A pointer to coroutine resume intrinsic wrapper. 949 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext()); 950 auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 951 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); 952 953 // Call async runtime API to execute a coroutine in the managed thread. 954 auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); 955 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute, 956 ValueRange({coroHdl, resumePtr.res()})); 957 958 return success(); 959 } 960 }; 961 } // namespace 962 963 //===----------------------------------------------------------------------===// 964 // Convert async.runtime.store to the corresponding runtime API call. 965 //===----------------------------------------------------------------------===// 966 967 namespace { 968 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> { 969 public: 970 using OpConversionPattern::OpConversionPattern; 971 972 LogicalResult 973 matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands, 974 ConversionPatternRewriter &rewriter) const override { 975 Location loc = op->getLoc(); 976 977 // Get a pointer to the async value storage from the runtime. 978 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 979 auto storage = RuntimeStoreOpAdaptor(operands).storage(); 980 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 981 TypeRange(i8Ptr), storage); 982 983 // Cast from i8* to the LLVM pointer type. 984 auto valueType = op.value().getType(); 985 auto llvmValueType = getTypeConverter()->convertType(valueType); 986 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 987 loc, LLVM::LLVMPointerType::get(llvmValueType), 988 storagePtr.getResult(0)); 989 990 // Store the yielded value into the async value storage. 991 auto value = RuntimeStoreOpAdaptor(operands).value(); 992 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult()); 993 994 // Erase the original runtime store operation. 995 rewriter.eraseOp(op); 996 997 return success(); 998 } 999 }; 1000 } // namespace 1001 1002 //===----------------------------------------------------------------------===// 1003 // Convert async.runtime.load to the corresponding runtime API call. 1004 //===----------------------------------------------------------------------===// 1005 1006 namespace { 1007 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> { 1008 public: 1009 using OpConversionPattern::OpConversionPattern; 1010 1011 LogicalResult 1012 matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands, 1013 ConversionPatternRewriter &rewriter) const override { 1014 Location loc = op->getLoc(); 1015 1016 // Get a pointer to the async value storage from the runtime. 1017 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 1018 auto storage = RuntimeLoadOpAdaptor(operands).storage(); 1019 auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage, 1020 TypeRange(i8Ptr), storage); 1021 1022 // Cast from i8* to the LLVM pointer type. 1023 auto valueType = op.result().getType(); 1024 auto llvmValueType = getTypeConverter()->convertType(valueType); 1025 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>( 1026 loc, LLVM::LLVMPointerType::get(llvmValueType), 1027 storagePtr.getResult(0)); 1028 1029 // Load from the casted pointer. 1030 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult()); 1031 1032 return success(); 1033 } 1034 }; 1035 } // namespace 1036 1037 //===----------------------------------------------------------------------===// 1038 // Convert async.runtime.add_to_group to the corresponding runtime API call. 1039 //===----------------------------------------------------------------------===// 1040 1041 namespace { 1042 class RuntimeAddToGroupOpLowering 1043 : public OpConversionPattern<RuntimeAddToGroupOp> { 1044 public: 1045 using OpConversionPattern::OpConversionPattern; 1046 1047 LogicalResult 1048 matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands, 1049 ConversionPatternRewriter &rewriter) const override { 1050 // Currently we can only add tokens to the group. 1051 if (!op.operand().getType().isa<TokenType>()) 1052 return rewriter.notifyMatchFailure(op, "only token type is supported"); 1053 1054 // Replace with a runtime API function call. 1055 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, 1056 rewriter.getI64Type(), operands); 1057 1058 return success(); 1059 } 1060 }; 1061 } // namespace 1062 1063 //===----------------------------------------------------------------------===// 1064 // Async reference counting ops lowering (`async.runtime.add_ref` and 1065 // `async.runtime.drop_ref` to the corresponding API calls). 1066 //===----------------------------------------------------------------------===// 1067 1068 namespace { 1069 template <typename RefCountingOp> 1070 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 1071 public: 1072 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 1073 StringRef apiFunctionName) 1074 : OpConversionPattern<RefCountingOp>(converter, ctx), 1075 apiFunctionName(apiFunctionName) {} 1076 1077 LogicalResult 1078 matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands, 1079 ConversionPatternRewriter &rewriter) const override { 1080 auto count = 1081 rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(), 1082 rewriter.getI32IntegerAttr(op.count())); 1083 1084 auto operand = typename RefCountingOp::Adaptor(operands).operand(); 1085 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 1086 ValueRange({operand, count})); 1087 1088 return success(); 1089 } 1090 1091 private: 1092 StringRef apiFunctionName; 1093 }; 1094 1095 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 1096 public: 1097 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 1098 : RefCountingOpLowering(converter, ctx, kAddRef) {} 1099 }; 1100 1101 class RuntimeDropRefOpLowering 1102 : public RefCountingOpLowering<RuntimeDropRefOp> { 1103 public: 1104 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 1105 : RefCountingOpLowering(converter, ctx, kDropRef) {} 1106 }; 1107 } // namespace 1108 1109 //===----------------------------------------------------------------------===// 1110 // Convert return operations that return async values from async regions. 1111 //===----------------------------------------------------------------------===// 1112 1113 namespace { 1114 class ReturnOpOpConversion : public OpConversionPattern<ReturnOp> { 1115 public: 1116 using OpConversionPattern::OpConversionPattern; 1117 1118 LogicalResult 1119 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 1120 ConversionPatternRewriter &rewriter) const override { 1121 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 1122 return success(); 1123 } 1124 }; 1125 } // namespace 1126 1127 //===----------------------------------------------------------------------===// 1128 // Convert async.create_group operation to async.runtime.create 1129 //===----------------------------------------------------------------------===// 1130 1131 namespace { 1132 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 1133 public: 1134 using OpConversionPattern::OpConversionPattern; 1135 1136 LogicalResult 1137 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 1138 ConversionPatternRewriter &rewriter) const override { 1139 rewriter.replaceOpWithNewOp<RuntimeCreateOp>( 1140 op, GroupType::get(op->getContext())); 1141 return success(); 1142 } 1143 }; 1144 } // namespace 1145 1146 //===----------------------------------------------------------------------===// 1147 // Convert async.add_to_group operation to async.runtime.add_to_group. 1148 //===----------------------------------------------------------------------===// 1149 1150 namespace { 1151 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 1152 public: 1153 using OpConversionPattern::OpConversionPattern; 1154 1155 LogicalResult 1156 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 1157 ConversionPatternRewriter &rewriter) const override { 1158 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 1159 op, rewriter.getIndexType(), operands); 1160 return success(); 1161 } 1162 }; 1163 } // namespace 1164 1165 //===----------------------------------------------------------------------===// 1166 // Convert async.await and async.await_all operations to the async.runtime.await 1167 // or async.runtime.await_and_resume operations. 1168 //===----------------------------------------------------------------------===// 1169 1170 namespace { 1171 template <typename AwaitType, typename AwaitableType> 1172 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 1173 using AwaitAdaptor = typename AwaitType::Adaptor; 1174 1175 public: 1176 AwaitOpLoweringBase( 1177 MLIRContext *ctx, 1178 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 1179 : OpConversionPattern<AwaitType>(ctx), 1180 outlinedFunctions(outlinedFunctions) {} 1181 1182 LogicalResult 1183 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 1184 ConversionPatternRewriter &rewriter) const override { 1185 // We can only await on one the `AwaitableType` (for `await` it can be 1186 // a `token` or a `value`, for `await_all` it must be a `group`). 1187 if (!op.operand().getType().template isa<AwaitableType>()) 1188 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 1189 1190 // Check if await operation is inside the outlined coroutine function. 1191 auto func = op->template getParentOfType<FuncOp>(); 1192 auto outlined = outlinedFunctions.find(func); 1193 const bool isInCoroutine = outlined != outlinedFunctions.end(); 1194 1195 Location loc = op->getLoc(); 1196 Value operand = AwaitAdaptor(operands).operand(); 1197 1198 // Inside regular functions we use the blocking wait operation to wait for 1199 // the async object (token, value or group) to become available. 1200 if (!isInCoroutine) 1201 rewriter.create<RuntimeAwaitOp>(loc, operand); 1202 1203 // Inside the coroutine we convert await operation into coroutine suspension 1204 // point, and resume execution asynchronously. 1205 if (isInCoroutine) { 1206 const CoroMachinery &coro = outlined->getSecond(); 1207 Block *suspended = op->getBlock(); 1208 1209 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 1210 MLIRContext *ctx = op->getContext(); 1211 1212 // Save the coroutine state and resume on a runtime managed thread when 1213 // the operand becomes available. 1214 auto coroSaveOp = 1215 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 1216 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 1217 1218 // Split the entry block before the await operation. 1219 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 1220 1221 // Add async.coro.suspend as a suspended block terminator. 1222 builder.setInsertionPointToEnd(suspended); 1223 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 1224 coro.cleanup); 1225 1226 // Make sure that replacement value will be constructed in resume block. 1227 rewriter.setInsertionPointToStart(resume); 1228 } 1229 1230 // Erase or replace the await operation with the new value. 1231 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 1232 rewriter.replaceOp(op, replaceWith); 1233 else 1234 rewriter.eraseOp(op); 1235 1236 return success(); 1237 } 1238 1239 virtual Value getReplacementValue(AwaitType op, Value operand, 1240 ConversionPatternRewriter &rewriter) const { 1241 return Value(); 1242 } 1243 1244 private: 1245 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 1246 }; 1247 1248 /// Lowering for `async.await` with a token operand. 1249 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 1250 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 1251 1252 public: 1253 using Base::Base; 1254 }; 1255 1256 /// Lowering for `async.await` with a value operand. 1257 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 1258 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 1259 1260 public: 1261 using Base::Base; 1262 1263 Value 1264 getReplacementValue(AwaitOp op, Value operand, 1265 ConversionPatternRewriter &rewriter) const override { 1266 // Load from the async value storage. 1267 auto valueType = operand.getType().cast<ValueType>().getValueType(); 1268 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 1269 } 1270 }; 1271 1272 /// Lowering for `async.await_all` operation. 1273 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 1274 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 1275 1276 public: 1277 using Base::Base; 1278 }; 1279 1280 } // namespace 1281 1282 //===----------------------------------------------------------------------===// 1283 // Convert async.yield operation to async.runtime operations. 1284 //===----------------------------------------------------------------------===// 1285 1286 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 1287 public: 1288 YieldOpLowering( 1289 MLIRContext *ctx, 1290 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 1291 : OpConversionPattern<async::YieldOp>(ctx), 1292 outlinedFunctions(outlinedFunctions) {} 1293 1294 LogicalResult 1295 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 1296 ConversionPatternRewriter &rewriter) const override { 1297 // Check if yield operation is inside the outlined coroutine function. 1298 auto func = op->template getParentOfType<FuncOp>(); 1299 auto outlined = outlinedFunctions.find(func); 1300 if (outlined == outlinedFunctions.end()) 1301 return rewriter.notifyMatchFailure( 1302 op, "operation is not inside the outlined async.execute function"); 1303 1304 Location loc = op->getLoc(); 1305 const CoroMachinery &coro = outlined->getSecond(); 1306 1307 // Store yielded values into the async values storage and switch async 1308 // values state to available. 1309 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 1310 Value yieldValue = std::get<0>(tuple); 1311 Value asyncValue = std::get<1>(tuple); 1312 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 1313 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 1314 } 1315 1316 // Switch the coroutine completion token to available state. 1317 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 1318 1319 return success(); 1320 } 1321 1322 private: 1323 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 1324 }; 1325 1326 //===----------------------------------------------------------------------===// 1327 1328 namespace { 1329 struct ConvertAsyncToLLVMPass 1330 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 1331 void runOnOperation() override; 1332 }; 1333 } // namespace 1334 1335 void ConvertAsyncToLLVMPass::runOnOperation() { 1336 ModuleOp module = getOperation(); 1337 SymbolTable symbolTable(module); 1338 1339 MLIRContext *ctx = &getContext(); 1340 1341 // Outline all `async.execute` body regions into async functions (coroutines). 1342 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 1343 1344 // We use conversion to LLVM type to ensure that all `async.value` operands 1345 // and results can be lowered to LLVM load and store operations. 1346 LLVMTypeConverter llvmConverter(ctx); 1347 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 1348 1349 // Returns true if the `async.value` payload is convertible to LLVM. 1350 auto isConvertibleToLlvm = [&](Type type) -> bool { 1351 auto valueType = type.cast<ValueType>().getValueType(); 1352 return static_cast<bool>(llvmConverter.convertType(valueType)); 1353 }; 1354 1355 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 1356 // All operands and results must be convertible to LLVM. 1357 if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) { 1358 execute.emitOpError("operands payload must be convertible to LLVM type"); 1359 return WalkResult::interrupt(); 1360 } 1361 if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) { 1362 execute.emitOpError("results payload must be convertible to LLVM type"); 1363 return WalkResult::interrupt(); 1364 } 1365 1366 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 1367 1368 return WalkResult::advance(); 1369 }); 1370 1371 // Failed to outline all async execute operations. 1372 if (outlineResult.wasInterrupted()) { 1373 signalPassFailure(); 1374 return; 1375 } 1376 1377 LLVM_DEBUG({ 1378 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 1379 << " async functions\n"; 1380 }); 1381 1382 // Add declarations for all functions required by the coroutines lowering. 1383 addResumeFunction(module); 1384 addAsyncRuntimeApiDeclarations(module); 1385 addCoroutineIntrinsicsDeclarations(module); 1386 addCRuntimeDeclarations(module); 1387 1388 // ------------------------------------------------------------------------ // 1389 // Lower async operations to async.runtime operations. 1390 // ------------------------------------------------------------------------ // 1391 OwningRewritePatternList asyncPatterns; 1392 1393 // Async lowering does not use type converter because it must preserve all 1394 // types for async.runtime operations. 1395 asyncPatterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 1396 asyncPatterns.insert<AwaitTokenOpLowering, AwaitValueOpLowering, 1397 AwaitAllOpLowering, YieldOpLowering>(ctx, 1398 outlinedFunctions); 1399 1400 // All high level async operations must be lowered to the runtime operations. 1401 ConversionTarget runtimeTarget(*ctx); 1402 runtimeTarget.addLegalDialect<AsyncDialect>(); 1403 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 1404 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 1405 1406 if (failed(applyPartialConversion(module, runtimeTarget, 1407 std::move(asyncPatterns)))) { 1408 signalPassFailure(); 1409 return; 1410 } 1411 1412 // ------------------------------------------------------------------------ // 1413 // Lower async.runtime and async.coro operations to Async Runtime API and 1414 // LLVM coroutine intrinsics. 1415 // ------------------------------------------------------------------------ // 1416 1417 // Convert async dialect types and operations to LLVM dialect. 1418 AsyncRuntimeTypeConverter converter; 1419 OwningRewritePatternList patterns; 1420 1421 // Convert async types in function signatures and function calls. 1422 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 1423 populateCallOpTypeConversionPattern(patterns, ctx, converter); 1424 1425 // Convert return operations inside async.execute regions. 1426 patterns.insert<ReturnOpOpConversion>(converter, ctx); 1427 1428 // Lower async.runtime operations to the async runtime API calls. 1429 patterns.insert<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering, 1430 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 1431 RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering, 1432 RuntimeDropRefOpLowering>(converter, ctx); 1433 1434 // Lower async.runtime operations that rely on LLVM type converter to convert 1435 // from async value payload type to the LLVM type. 1436 patterns.insert<RuntimeCreateOpLowering, RuntimeStoreOpLowering, 1437 RuntimeLoadOpLowering>(llvmConverter, ctx); 1438 1439 // Lower async coroutine operations to LLVM coroutine intrinsics. 1440 patterns.insert<CoroIdOpConversion, CoroBeginOpConversion, 1441 CoroFreeOpConversion, CoroEndOpConversion, 1442 CoroSaveOpConversion, CoroSuspendOpConversion>(converter, 1443 ctx); 1444 1445 ConversionTarget target(*ctx); 1446 target.addLegalOp<ConstantOp>(); 1447 target.addLegalDialect<LLVM::LLVMDialect>(); 1448 1449 // All operations from Async dialect must be lowered to the runtime API and 1450 // LLVM intrinsics calls. 1451 target.addIllegalDialect<AsyncDialect>(); 1452 1453 // Add dynamic legality constraints to apply conversions defined above. 1454 target.addDynamicallyLegalOp<FuncOp>( 1455 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1456 target.addDynamicallyLegalOp<ReturnOp>( 1457 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1458 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1459 return converter.isSignatureLegal(op.getCalleeType()); 1460 }); 1461 1462 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1463 signalPassFailure(); 1464 } 1465 1466 //===----------------------------------------------------------------------===// 1467 // Patterns for structural type conversions for the Async dialect operations. 1468 //===----------------------------------------------------------------------===// 1469 1470 namespace { 1471 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1472 public: 1473 using OpConversionPattern::OpConversionPattern; 1474 LogicalResult 1475 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands, 1476 ConversionPatternRewriter &rewriter) const override { 1477 ExecuteOp newOp = 1478 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1479 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1480 newOp.getRegion().end()); 1481 1482 // Set operands and update block argument and result types. 1483 newOp->setOperands(operands); 1484 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1485 return failure(); 1486 for (auto result : newOp.getResults()) 1487 result.setType(typeConverter->convertType(result.getType())); 1488 1489 rewriter.replaceOp(op, newOp.getResults()); 1490 return success(); 1491 } 1492 }; 1493 1494 // Dummy pattern to trigger the appropriate type conversion / materialization. 1495 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1496 public: 1497 using OpConversionPattern::OpConversionPattern; 1498 LogicalResult 1499 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands, 1500 ConversionPatternRewriter &rewriter) const override { 1501 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front()); 1502 return success(); 1503 } 1504 }; 1505 1506 // Dummy pattern to trigger the appropriate type conversion / materialization. 1507 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1508 public: 1509 using OpConversionPattern::OpConversionPattern; 1510 LogicalResult 1511 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 1512 ConversionPatternRewriter &rewriter) const override { 1513 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands); 1514 return success(); 1515 } 1516 }; 1517 } // namespace 1518 1519 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1520 return std::make_unique<ConvertAsyncToLLVMPass>(); 1521 } 1522 1523 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1524 MLIRContext *context, TypeConverter &typeConverter, 1525 OwningRewritePatternList &patterns, ConversionTarget &target) { 1526 typeConverter.addConversion([&](TokenType type) { return type; }); 1527 typeConverter.addConversion([&](ValueType type) { 1528 return ValueType::get(typeConverter.convertType(type.getValueType())); 1529 }); 1530 1531 patterns 1532 .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1533 typeConverter, context); 1534 1535 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1536 [&](Operation *op) { return typeConverter.isLegal(op); }); 1537 } 1538