1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/BlockAndValueMapping.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "mlir/Transforms/RegionUtils.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/Support/FormatVariadic.h" 25 26 #define DEBUG_TYPE "convert-async-to-llvm" 27 28 using namespace mlir; 29 using namespace mlir::async; 30 31 // Prefix for functions outlined from `async.execute` op regions. 32 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 33 34 //===----------------------------------------------------------------------===// 35 // Async Runtime C API declaration. 
36 //===----------------------------------------------------------------------===// 37 38 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 39 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 40 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 41 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 42 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 43 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 44 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 45 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 46 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 47 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 48 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 49 static constexpr const char *kGetValueStorage = 50 "mlirAsyncRuntimeGetValueStorage"; 51 static constexpr const char *kAddTokenToGroup = 52 "mlirAsyncRuntimeAddTokenToGroup"; 53 static constexpr const char *kAwaitTokenAndExecute = 54 "mlirAsyncRuntimeAwaitTokenAndExecute"; 55 static constexpr const char *kAwaitValueAndExecute = 56 "mlirAsyncRuntimeAwaitValueAndExecute"; 57 static constexpr const char *kAwaitAllAndExecute = 58 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 59 60 namespace { 61 /// Async Runtime API function types. 62 /// 63 /// Because we can't create API function signature for type parametrized 64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 65 /// lowering all async data types become opaque pointers at runtime. 66 struct AsyncAPI { 67 // All async types are lowered to opaque i8* LLVM pointers at runtime. 
68 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 69 return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); 70 } 71 72 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 73 auto ref = opaquePointerType(ctx); 74 auto count = IntegerType::get(ctx, 32); 75 return FunctionType::get(ctx, {ref, count}, {}); 76 } 77 78 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 79 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 80 } 81 82 static FunctionType createValueFunctionType(MLIRContext *ctx) { 83 auto i32 = IntegerType::get(ctx, 32); 84 auto value = opaquePointerType(ctx); 85 return FunctionType::get(ctx, {i32}, {value}); 86 } 87 88 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 89 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 90 } 91 92 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 93 auto value = opaquePointerType(ctx); 94 auto storage = opaquePointerType(ctx); 95 return FunctionType::get(ctx, {value}, {storage}); 96 } 97 98 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 99 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 100 } 101 102 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 103 auto value = opaquePointerType(ctx); 104 return FunctionType::get(ctx, {value}, {}); 105 } 106 107 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 108 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 109 } 110 111 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 112 auto value = opaquePointerType(ctx); 113 return FunctionType::get(ctx, {value}, {}); 114 } 115 116 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 117 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 118 } 119 120 static FunctionType executeFunctionType(MLIRContext *ctx) { 121 auto hdl = opaquePointerType(ctx); 122 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 123 return 
FunctionType::get(ctx, {hdl, resume}, {}); 124 } 125 126 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 127 auto i64 = IntegerType::get(ctx, 64); 128 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 129 {i64}); 130 } 131 132 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 133 auto hdl = opaquePointerType(ctx); 134 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 135 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 136 } 137 138 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 139 auto value = opaquePointerType(ctx); 140 auto hdl = opaquePointerType(ctx); 141 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 142 return FunctionType::get(ctx, {value, hdl, resume}, {}); 143 } 144 145 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 146 auto hdl = opaquePointerType(ctx); 147 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 148 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 149 } 150 151 // Auxiliary coroutine resume intrinsic wrapper. 152 static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { 153 auto voidTy = LLVM::LLVMVoidType::get(ctx); 154 auto i8Ptr = opaquePointerType(ctx); 155 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 156 } 157 }; 158 } // namespace 159 160 /// Adds Async Runtime C API declarations to the module. 
161 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 162 auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), 163 module.getBody()); 164 165 auto addFuncDecl = [&](StringRef name, FunctionType type) { 166 if (module.lookupSymbol(name)) 167 return; 168 builder.create<FuncOp>(name, type).setPrivate(); 169 }; 170 171 MLIRContext *ctx = module.getContext(); 172 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 173 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 174 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 175 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 176 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 177 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 178 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 179 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 180 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 181 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 182 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 183 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 184 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 185 addFuncDecl(kAwaitTokenAndExecute, 186 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 187 addFuncDecl(kAwaitValueAndExecute, 188 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 189 addFuncDecl(kAwaitAllAndExecute, 190 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 191 } 192 193 //===----------------------------------------------------------------------===// 194 // LLVM coroutines intrinsics declarations. 
195 //===----------------------------------------------------------------------===// 196 197 static constexpr const char *kCoroId = "llvm.coro.id"; 198 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 199 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 200 static constexpr const char *kCoroSave = "llvm.coro.save"; 201 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 202 static constexpr const char *kCoroEnd = "llvm.coro.end"; 203 static constexpr const char *kCoroFree = "llvm.coro.free"; 204 static constexpr const char *kCoroResume = "llvm.coro.resume"; 205 206 /// Adds an LLVM function declaration to a module. 207 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 208 StringRef name, LLVM::LLVMType ret, 209 ArrayRef<LLVM::LLVMType> params) { 210 if (module.lookupSymbol(name)) 211 return; 212 LLVM::LLVMType type = LLVM::LLVMFunctionType::get(ret, params); 213 builder.create<LLVM::LLVMFuncOp>(name, type); 214 } 215 216 /// Adds coroutine intrinsics declarations to the module. 
217 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 218 using namespace mlir::LLVM; 219 220 MLIRContext *ctx = module.getContext(); 221 ImplicitLocOpBuilder builder(module.getLoc(), 222 module.getBody()->getTerminator()); 223 224 auto token = LLVMTokenType::get(ctx); 225 auto voidTy = LLVMVoidType::get(ctx); 226 227 auto i8 = LLVMIntegerType::get(ctx, 8); 228 auto i1 = LLVMIntegerType::get(ctx, 1); 229 auto i32 = LLVMIntegerType::get(ctx, 32); 230 auto i64 = LLVMIntegerType::get(ctx, 64); 231 auto i8Ptr = LLVMPointerType::get(i8); 232 233 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); 234 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); 235 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); 236 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); 237 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); 238 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); 239 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); 240 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // Add malloc/free declarations to the module. 245 //===----------------------------------------------------------------------===// 246 247 static constexpr const char *kMalloc = "malloc"; 248 static constexpr const char *kFree = "free"; 249 250 /// Adds malloc/free declarations to the module. 
251 static void addCRuntimeDeclarations(ModuleOp module) { 252 using namespace mlir::LLVM; 253 254 MLIRContext *ctx = module.getContext(); 255 ImplicitLocOpBuilder builder(module.getLoc(), 256 module.getBody()->getTerminator()); 257 258 auto voidTy = LLVMVoidType::get(ctx); 259 auto i64 = LLVMIntegerType::get(ctx, 64); 260 auto i8Ptr = LLVMPointerType::get(LLVMIntegerType::get(ctx, 8)); 261 262 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 263 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 264 } 265 266 //===----------------------------------------------------------------------===// 267 // Coroutine resume function wrapper. 268 //===----------------------------------------------------------------------===// 269 270 static constexpr const char *kResume = "__resume"; 271 272 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 273 /// intrinsics. We need this function to be able to pass it to the async 274 /// runtime execute API. 275 static void addResumeFunction(ModuleOp module) { 276 MLIRContext *ctx = module.getContext(); 277 278 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 279 Location loc = module.getLoc(); 280 281 if (module.lookupSymbol(kResume)) 282 return; 283 284 auto voidTy = LLVM::LLVMVoidType::get(ctx); 285 auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); 286 287 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 288 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 289 resumeOp.setPrivate(); 290 291 auto *block = resumeOp.addEntryBlock(); 292 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 293 294 blockBuilder.create<LLVM::CallOp>(TypeRange(), 295 blockBuilder.getSymbolRefAttr(kCoroResume), 296 resumeOp.getArgument(0)); 297 298 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // async.execute op outlining to the coroutine functions. 
//===----------------------------------------------------------------------===//

/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension.
///
/// async.await op lowering additionally creates a resume block for each
/// operation to enable non-blocking waiting via coroutine suspension.
namespace {
struct CoroMachinery {
  // Async execute region returns a completion token, and an async value for
  // each yielded value.
  //
  //   %token, %result = async.execute -> !async.value<T> {
  //     %0 = constant ... : T
  //     async.yield %0 : T
  //   }
  Value asyncToken; // token representing completion of the async region
  llvm::SmallVector<Value, 4> returnValues; // returned async values

  Value coroHandle; // coroutine handle (result of the @llvm.coro.begin call)
  Block *cleanup;   // block that frees the coroutine frame
  Block *suspend;   // block that ends the coroutine and returns the results
};
} // namespace

/// Builds a coroutine template compatible with LLVM coroutines lowering.
///
///  - `entry` block sets up the coroutine.
///  - `cleanup` block cleans up the coroutine state.
///  - `suspend` block after the @llvm.coro.end() defines what value will be
///    returned to the initial caller of a coroutine. Everything before the
///    @llvm.coro.end() will be executed at every suspension point.
///
/// Coroutine structure (only the important bits):
///
///   func @async_execute_fn(<function-arguments>)
///        -> (!async.token, !async.value<T>)
///   {
///     ^entryBlock(<function-arguments>):
///       %token = <async token> : !async.token    // create async runtime token
///       %value = <async value> : !async.value<T> // create async value
///       %hdl = llvm.call @llvm.coro.id(...)      // create a coroutine handle
///       br ^cleanup
///
///     ^cleanup:
///       llvm.call @llvm.coro.free(...)  // delete coroutine state
///       br ^suspend
///
///     ^suspend:
///       llvm.call @llvm.coro.end(...)  // marks the end of a coroutine
///       return %token, %value : !async.token, !async.value<T>
///   }
///
/// The actual code for the async.execute operation body region will be inserted
/// before the entry block terminator.
///
/// `func` must have an empty body (asserted). The returned `CoroMachinery`
/// holds the created token, the created async values, the coroutine handle and
/// the cleanup/suspend blocks.
///
static CoroMachinery setupCoroMachinery(FuncOp func) {
  assert(func.getBody().empty() && "Function must have empty body");

  MLIRContext *ctx = func.getContext();

  // LLVM dialect types used by the coroutine intrinsics below.
  auto token = LLVM::LLVMTokenType::get(ctx);
  auto i1 = LLVM::LLVMIntegerType::get(ctx, 1);
  auto i32 = LLVM::LLVMIntegerType::get(ctx, 32);
  auto i64 = LLVM::LLVMIntegerType::get(ctx, 64);
  auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));

  Block *entryBlock = func.addEntryBlock();
  Location loc = func.getBody().getLoc();

  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock);

  // ------------------------------------------------------------------------ //
  // Allocate async tokens/values that we will return from a ramp function.
  // ------------------------------------------------------------------------ //
  auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));

  // Async value operands and results must be convertible to LLVM types. This is
  // verified before the function outlining.
  LLVMTypeConverter converter(ctx);

  // Returns the size requirements for the async value storage.
  // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
  auto sizeOf = [&](ValueType valueType) -> Value {
    auto storedType = converter.convertType(valueType.getValueType());
    auto storagePtrType =
        LLVM::LLVMPointerType::get(storedType.cast<LLVM::LLVMType>());

    //  %Size = getelementptr %T* null, int 1
    //  %SizeI = ptrtoint %T* %Size to i32
    auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType);
    auto one = builder.create<LLVM::ConstantOp>(loc, i32,
                                                builder.getI32IntegerAttr(1));
    auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
                                           one.getResult());
    auto size = builder.create<LLVM::PtrToIntOp>(loc, i32, gep);

    // Cast to std type because runtime API defined using std types.
    return builder.create<LLVM::DialectCastOp>(loc, builder.getI32Type(),
                                               size.getResult());
  };

  // We use the `async.value` type as a return type although it does not match
  // the `kCreateValue` function signature, because it will be later lowered to
  // the runtime type (opaque i8* pointer).
  // The first function result is skipped (`drop_front(1)`): it is the
  // completion token created above.
  llvm::SmallVector<CallOp, 4> createValues;
  for (auto resultType : func.getCallableResults().drop_front(1))
    createValues.emplace_back(builder.create<CallOp>(
        loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>())));

  auto createdValues = llvm::map_range(
      createValues, [](CallOp call) { return call.getResult(0); });
  llvm::SmallVector<Value, 4> returnValues(createdValues.begin(),
                                           createdValues.end());

  // ------------------------------------------------------------------------ //
  // Initialize coroutine: allocate frame, get coroutine handle.
  // ------------------------------------------------------------------------ //

  // Constants for initializing coroutine frame.
  auto constZero =
      builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0));
  auto constFalse =
      builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false));
  auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr);

  // Get coroutine id: @llvm.coro.id
  auto coroId = builder.create<LLVM::CallOp>(
      token, builder.getSymbolRefAttr(kCoroId),
      ValueRange({constZero, nullPtr, nullPtr, nullPtr}));

  // Get coroutine frame size: @llvm.coro.size.i64
  auto coroSize = builder.create<LLVM::CallOp>(
      i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange());

  // Allocate memory for coroutine frame.
  auto coroAlloc =
      builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc),
                                   ValueRange(coroSize.getResult(0)));

  // Begin a coroutine: @llvm.coro.begin
  auto coroHdl = builder.create<LLVM::CallOp>(
      i8Ptr, builder.getSymbolRefAttr(kCoroBegin),
      ValueRange({coroId.getResult(0), coroAlloc.getResult(0)}));

  Block *cleanupBlock = func.addBlock();
  Block *suspendBlock = func.addBlock();

  // ------------------------------------------------------------------------ //
  // Coroutine cleanup block: deallocate coroutine frame, free the memory.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(cleanupBlock);

  // Get a pointer to the coroutine frame memory: @llvm.coro.free.
  auto coroMem = builder.create<LLVM::CallOp>(
      i8Ptr, builder.getSymbolRefAttr(kCoroFree),
      ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));

  // Free the memory.
  builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree),
                               ValueRange(coroMem.getResult(0)));
  // Branch into the suspend block.
  builder.create<BranchOp>(suspendBlock);

  // ------------------------------------------------------------------------ //
  // Coroutine suspend block: mark the end of a coroutine and return allocated
  // async token.
  // ------------------------------------------------------------------------ //
  builder.setInsertionPointToStart(suspendBlock);

  // Mark the end of a coroutine: @llvm.coro.end.
  builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
                               ValueRange({coroHdl.getResult(0), constFalse}));

  // Return created `async.token` and `async.values` from the suspend block.
  // This will be the return value of a coroutine ramp function.
  SmallVector<Value, 4> ret{createToken.getResult(0)};
  ret.insert(ret.end(), returnValues.begin(), returnValues.end());
  builder.create<ReturnOp>(loc, ret);

  // Branch from the entry block to the cleanup block to create a valid CFG.
  builder.setInsertionPointToEnd(entryBlock);

  builder.create<BranchOp>(cleanupBlock);

  // `async.await` op lowering will create resume blocks for async
  // continuations, and will conditionally branch to cleanup or suspend blocks.

  // Bundle everything later lowering steps need to refer back to.
  CoroMachinery machinery;
  machinery.asyncToken = createToken.getResult(0);
  machinery.returnValues = returnValues;
  machinery.coroHandle = coroHdl.getResult(0);
  machinery.cleanup = cleanupBlock;
  machinery.suspend = suspendBlock;
  return machinery;
}

/// Add a LLVM coroutine suspension point to the end of suspended block, to
/// resume execution in resume block. The caller is responsible for creating the
/// two suspended/resume blocks with the desired ops contained in each block.
/// This function merely provides the required control flow logic.
///
/// `coroState` must be a value returned from the call to @llvm.coro.save(...)
/// intrinsic (saved coroutine state).
508 /// 509 /// Before: 510 /// 511 /// ^bb0: 512 /// "opBefore"(...) 513 /// "op"(...) 514 /// ^cleanup: ... 515 /// ^suspend: ... 516 /// ^resume: 517 /// "op"(...) 518 /// 519 /// After: 520 /// 521 /// ^bb0: 522 /// "opBefore"(...) 523 /// %suspend = llmv.call @llvm.coro.suspend(...) 524 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 525 /// ^resume: 526 /// "op"(...) 527 /// ^cleanup: ... 528 /// ^suspend: ... 529 /// 530 static void addSuspensionPoint(CoroMachinery coro, Value coroState, 531 Operation *op, Block *suspended, Block *resume, 532 OpBuilder &builder) { 533 Location loc = op->getLoc(); 534 MLIRContext *ctx = op->getContext(); 535 auto i1 = LLVM::LLVMIntegerType::get(ctx, 1); 536 auto i8 = LLVM::LLVMIntegerType::get(ctx, 8); 537 538 // Add a coroutine suspension in place of original `op` in the split block. 539 OpBuilder::InsertionGuard guard(builder); 540 builder.setInsertionPointToEnd(suspended); 541 542 auto constFalse = 543 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 544 545 // Suspend a coroutine: @llvm.coro.suspend 546 auto coroSuspend = builder.create<LLVM::CallOp>( 547 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 548 ValueRange({coroState, constFalse})); 549 550 // After a suspension point decide if we should branch into resume, cleanup 551 // or suspend block of the coroutine (see @llvm.coro.suspend return code 552 // documentation). 553 auto constZero = 554 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 555 auto constNegOne = 556 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 557 558 Block *resumeOrCleanup = builder.createBlock(resume); 559 560 // Suspend the coroutine ...? 
561 builder.setInsertionPointToEnd(suspended); 562 auto isNegOne = builder.create<LLVM::ICmpOp>( 563 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 564 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 565 /*falseDest=*/resumeOrCleanup); 566 567 // ... or resume or cleanup the coroutine? 568 builder.setInsertionPointToStart(resumeOrCleanup); 569 auto isZero = builder.create<LLVM::ICmpOp>( 570 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 571 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 572 /*falseDest=*/coro.cleanup); 573 } 574 575 /// Outline the body region attached to the `async.execute` op into a standalone 576 /// function. 577 /// 578 /// Note that this is not reversible transformation. 579 static std::pair<FuncOp, CoroMachinery> 580 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 581 ModuleOp module = execute->getParentOfType<ModuleOp>(); 582 583 MLIRContext *ctx = module.getContext(); 584 Location loc = execute.getLoc(); 585 586 // Collect all outlined function inputs. 587 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 588 execute.dependencies().end()); 589 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 590 getUsedValuesDefinedAbove(execute.body(), functionInputs); 591 592 // Collect types for the outlined function inputs and outputs. 593 auto typesRange = llvm::map_range( 594 functionInputs, [](Value value) { return value.getType(); }); 595 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 596 auto outputTypes = execute.getResultTypes(); 597 598 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 599 auto funcAttrs = ArrayRef<NamedAttribute>(); 600 601 // TODO: Derive outlined function name from the parent FuncOp (support 602 // multiple nested async.execute operations). 
603 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 604 symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); 605 606 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 607 608 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 609 // blocks, adding llvm.coro instrinsics and setting up control flow. 610 CoroMachinery coro = setupCoroMachinery(func); 611 612 // Suspend async function at the end of an entry block, and resume it using 613 // Async execute API (execution will be resumed in a thread managed by the 614 // async runtime). 615 Block *entryBlock = &func.getBlocks().front(); 616 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 617 618 // A pointer to coroutine resume intrinsic wrapper. 619 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 620 auto resumePtr = builder.create<LLVM::AddressOfOp>( 621 LLVM::LLVMPointerType::get(resumeFnTy), kResume); 622 623 // Save the coroutine state: @llvm.coro.save 624 auto coroSave = builder.create<LLVM::CallOp>( 625 LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 626 ValueRange({coro.coroHandle})); 627 628 // Call async runtime API to execute a coroutine in the managed thread. 629 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 630 builder.create<CallOp>(TypeRange(), kExecute, executeArgs); 631 632 // Split the entry block before the terminator. 633 auto *terminatorOp = entryBlock->getTerminator(); 634 Block *suspended = terminatorOp->getBlock(); 635 Block *resume = suspended->splitBlock(terminatorOp); 636 addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, 637 resume, builder); 638 639 size_t numDependencies = execute.dependencies().size(); 640 size_t numOperands = execute.operands().size(); 641 642 // Await on all dependencies before starting to execute the body region. 
643 builder.setInsertionPointToStart(resume); 644 for (size_t i = 0; i < numDependencies; ++i) 645 builder.create<AwaitOp>(func.getArgument(i)); 646 647 // Await on all async value operands and unwrap the payload. 648 SmallVector<Value, 4> unwrappedOperands(numOperands); 649 for (size_t i = 0; i < numOperands; ++i) { 650 Value operand = func.getArgument(numDependencies + i); 651 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 652 } 653 654 // Map from function inputs defined above the execute op to the function 655 // arguments. 656 BlockAndValueMapping valueMapping; 657 valueMapping.map(functionInputs, func.getArguments()); 658 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 659 660 // Clone all operations from the execute operation body into the outlined 661 // function body. 662 for (Operation &op : execute.body().getOps()) 663 builder.clone(op, valueMapping); 664 665 // Replace the original `async.execute` with a call to outlined function. 666 ImplicitLocOpBuilder callBuilder(loc, execute); 667 auto callOutlinedFunc = callBuilder.create<CallOp>( 668 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 669 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 670 execute.erase(); 671 672 return {func, coro}; 673 } 674 675 //===----------------------------------------------------------------------===// 676 // Convert Async dialect types to LLVM types. 677 //===----------------------------------------------------------------------===// 678 679 namespace { 680 681 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 682 /// their runtime type (opaque pointers) and does not convert any other types. 
683 class AsyncRuntimeTypeConverter : public TypeConverter { 684 public: 685 AsyncRuntimeTypeConverter() { 686 addConversion([](Type type) { return type; }); 687 addConversion(convertAsyncTypes); 688 } 689 690 static Optional<Type> convertAsyncTypes(Type type) { 691 if (type.isa<TokenType, GroupType, ValueType>()) 692 return AsyncAPI::opaquePointerType(type.getContext()); 693 return llvm::None; 694 } 695 }; 696 } // namespace 697 698 //===----------------------------------------------------------------------===// 699 // Convert return operations that return async values from async regions. 700 //===----------------------------------------------------------------------===// 701 702 namespace { 703 class ReturnOpOpConversion : public ConversionPattern { 704 public: 705 explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx) 706 : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {} 707 708 LogicalResult 709 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 710 ConversionPatternRewriter &rewriter) const override { 711 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 712 return success(); 713 } 714 }; 715 } // namespace 716 717 //===----------------------------------------------------------------------===// 718 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` 719 // to the corresponding API calls). 
720 //===----------------------------------------------------------------------===// 721 722 namespace { 723 724 template <typename RefCountingOp> 725 class RefCountingOpLowering : public ConversionPattern { 726 public: 727 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 728 StringRef apiFunctionName) 729 : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx), 730 apiFunctionName(apiFunctionName) {} 731 732 LogicalResult 733 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 734 ConversionPatternRewriter &rewriter) const override { 735 RefCountingOp refCountingOp = cast<RefCountingOp>(op); 736 737 auto count = rewriter.create<ConstantOp>( 738 op->getLoc(), rewriter.getI32Type(), 739 rewriter.getI32IntegerAttr(refCountingOp.count())); 740 741 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 742 ValueRange({operands[0], count})); 743 744 return success(); 745 } 746 747 private: 748 StringRef apiFunctionName; 749 }; 750 751 /// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. 752 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> { 753 public: 754 explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 755 : RefCountingOpLowering(converter, ctx, kAddRef) {} 756 }; 757 758 /// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 759 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> { 760 public: 761 explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 762 : RefCountingOpLowering(converter, ctx, kDropRef) {} 763 }; 764 765 } // namespace 766 767 //===----------------------------------------------------------------------===// 768 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 
769 //===----------------------------------------------------------------------===// 770 771 namespace { 772 class CreateGroupOpLowering : public ConversionPattern { 773 public: 774 explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) 775 : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter, 776 ctx) {} 777 778 LogicalResult 779 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 780 ConversionPatternRewriter &rewriter) const override { 781 auto retTy = GroupType::get(op->getContext()); 782 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy); 783 return success(); 784 } 785 }; 786 } // namespace 787 788 //===----------------------------------------------------------------------===// 789 // async.add_to_group op lowering to runtime function call. 790 //===----------------------------------------------------------------------===// 791 792 namespace { 793 class AddToGroupOpLowering : public ConversionPattern { 794 public: 795 explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) 796 : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) { 797 } 798 799 LogicalResult 800 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 801 ConversionPatternRewriter &rewriter) const override { 802 // Currently we can only add tokens to the group. 803 auto addToGroup = cast<AddToGroupOp>(op); 804 if (!addToGroup.operand().getType().isa<TokenType>()) 805 return failure(); 806 807 auto i64 = IntegerType::get(op->getContext(), 64); 808 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); 809 return success(); 810 } 811 }; 812 } // namespace 813 814 //===----------------------------------------------------------------------===// 815 // async.await and async.await_all op lowerings to the corresponding async 816 // runtime function calls. 
817 //===----------------------------------------------------------------------===// 818 819 namespace { 820 821 template <typename AwaitType, typename AwaitableType> 822 class AwaitOpLoweringBase : public ConversionPattern { 823 protected: 824 explicit AwaitOpLoweringBase( 825 TypeConverter &converter, MLIRContext *ctx, 826 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions, 827 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) 828 : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx), 829 outlinedFunctions(outlinedFunctions), 830 blockingAwaitFuncName(blockingAwaitFuncName), 831 coroAwaitFuncName(coroAwaitFuncName) {} 832 833 public: 834 LogicalResult 835 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 836 ConversionPatternRewriter &rewriter) const override { 837 // We can only await on one the `AwaitableType` (for `await` it can be 838 // a `token` or a `value`, for `await_all` it must be a `group`). 839 auto await = cast<AwaitType>(op); 840 if (!await.operand().getType().template isa<AwaitableType>()) 841 return failure(); 842 843 // Check if await operation is inside the outlined coroutine function. 844 auto func = await->template getParentOfType<FuncOp>(); 845 auto outlined = outlinedFunctions.find(func); 846 const bool isInCoroutine = outlined != outlinedFunctions.end(); 847 848 Location loc = op->getLoc(); 849 850 // Inside regular function we convert await operation to the blocking 851 // async API await function call. 852 if (!isInCoroutine) 853 rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName, 854 ValueRange(operands[0])); 855 856 // Inside the coroutine we convert await operation into coroutine suspension 857 // point, and resume execution asynchronously. 
858 if (isInCoroutine) { 859 const CoroMachinery &coro = outlined->getSecond(); 860 861 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 862 MLIRContext *ctx = op->getContext(); 863 864 // A pointer to coroutine resume intrinsic wrapper. 865 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 866 auto resumePtr = builder.create<LLVM::AddressOfOp>( 867 LLVM::LLVMPointerType::get(resumeFnTy), kResume); 868 869 // Save the coroutine state: @llvm.coro.save 870 auto coroSave = builder.create<LLVM::CallOp>( 871 LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 872 ValueRange(coro.coroHandle)); 873 874 // Call async runtime API to resume a coroutine in the managed thread when 875 // the async await argument becomes ready. 876 SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle, 877 resumePtr.res()}; 878 builder.create<CallOp>(TypeRange(), coroAwaitFuncName, 879 awaitAndExecuteArgs); 880 881 Block *suspended = op->getBlock(); 882 883 // Split the entry block before the await operation. 884 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 885 addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume, 886 builder); 887 888 // Make sure that replacement value will be constructed in resume block. 889 rewriter.setInsertionPointToStart(resume); 890 } 891 892 // Replace or erase the await operation with the new value. 893 if (Value replaceWith = getReplacementValue(op, operands[0], rewriter)) 894 rewriter.replaceOp(op, replaceWith); 895 else 896 rewriter.eraseOp(op); 897 898 return success(); 899 } 900 901 virtual Value getReplacementValue(Operation *op, Value operand, 902 ConversionPatternRewriter &rewriter) const { 903 return Value(); 904 } 905 906 private: 907 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 908 StringRef blockingAwaitFuncName; 909 StringRef coroAwaitFuncName; 910 }; 911 912 /// Lowering for `async.await` with a token operand. 
913 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 914 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 915 916 public: 917 explicit AwaitTokenOpLowering( 918 TypeConverter &converter, MLIRContext *ctx, 919 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 920 : Base(converter, ctx, outlinedFunctions, kAwaitToken, 921 kAwaitTokenAndExecute) {} 922 }; 923 924 /// Lowering for `async.await` with a value operand. 925 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 926 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 927 928 public: 929 explicit AwaitValueOpLowering( 930 TypeConverter &converter, MLIRContext *ctx, 931 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 932 : Base(converter, ctx, outlinedFunctions, kAwaitValue, 933 kAwaitValueAndExecute) {} 934 935 Value 936 getReplacementValue(Operation *op, Value operand, 937 ConversionPatternRewriter &rewriter) const override { 938 Location loc = op->getLoc(); 939 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 940 941 // Get the underlying value type from the `async.value`. 942 auto await = cast<AwaitOp>(op); 943 auto valueType = await.operand().getType().cast<ValueType>().getValueType(); 944 945 // Get a pointer to an async value storage from the runtime. 946 auto storage = rewriter.create<CallOp>(loc, kGetValueStorage, 947 TypeRange(i8Ptr), operand); 948 949 // Cast from i8* to the pointer pointer to LLVM type. 950 auto llvmValueType = getTypeConverter()->convertType(valueType); 951 auto castedStorage = rewriter.create<LLVM::BitcastOp>( 952 loc, LLVM::LLVMPointerType::get(llvmValueType.cast<LLVM::LLVMType>()), 953 storage.getResult(0)); 954 955 // Load from the async value storage. 956 auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult()); 957 958 // Cast from LLVM type to the expected value type. This cast will become 959 // no-op after lowering to LLVM. 
960 return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded); 961 } 962 }; 963 964 /// Lowering for `async.await_all` operation. 965 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 966 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 967 968 public: 969 explicit AwaitAllOpLowering( 970 TypeConverter &converter, MLIRContext *ctx, 971 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 972 : Base(converter, ctx, outlinedFunctions, kAwaitGroup, 973 kAwaitAllAndExecute) {} 974 }; 975 976 } // namespace 977 978 //===----------------------------------------------------------------------===// 979 // async.yield op lowerings to the corresponding async runtime function calls. 980 //===----------------------------------------------------------------------===// 981 982 class YieldOpLowering : public ConversionPattern { 983 public: 984 explicit YieldOpLowering( 985 TypeConverter &converter, MLIRContext *ctx, 986 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 987 : ConversionPattern(async::YieldOp::getOperationName(), 1, converter, 988 ctx), 989 outlinedFunctions(outlinedFunctions) {} 990 991 LogicalResult 992 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 993 ConversionPatternRewriter &rewriter) const override { 994 // Check if yield operation is inside the outlined coroutine function. 995 auto func = op->template getParentOfType<FuncOp>(); 996 auto outlined = outlinedFunctions.find(func); 997 if (outlined == outlinedFunctions.end()) 998 return op->emitOpError( 999 "async.yield is not inside the outlined coroutine function"); 1000 1001 Location loc = op->getLoc(); 1002 const CoroMachinery &coro = outlined->getSecond(); 1003 1004 // Store yielded values into the async values storage and emplace them. 
1005 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 1006 1007 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 1008 // Store `yieldValue` into the `asyncValue` storage. 1009 Value yieldValue = std::get<0>(tuple); 1010 Value asyncValue = std::get<1>(tuple); 1011 1012 // Get an opaque i8* pointer to an async value storage from the runtime. 1013 auto storage = rewriter.create<CallOp>(loc, kGetValueStorage, 1014 TypeRange(i8Ptr), asyncValue); 1015 1016 // Cast storage pointer to the yielded value type. 1017 auto castedStorage = rewriter.create<LLVM::BitcastOp>( 1018 loc, 1019 LLVM::LLVMPointerType::get( 1020 yieldValue.getType().cast<LLVM::LLVMType>()), 1021 storage.getResult(0)); 1022 1023 // Store the yielded value into the async value storage. 1024 rewriter.create<LLVM::StoreOp>(loc, yieldValue, 1025 castedStorage.getResult()); 1026 1027 // Emplace the `async.value` to mark it ready. 1028 rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue); 1029 } 1030 1031 // Emplace the completion token to mark it ready. 1032 rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken); 1033 1034 // Original operation was replaced by the function call(s). 1035 rewriter.eraseOp(op); 1036 1037 return success(); 1038 } 1039 1040 private: 1041 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 1042 }; 1043 1044 //===----------------------------------------------------------------------===// 1045 1046 namespace { 1047 struct ConvertAsyncToLLVMPass 1048 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 1049 void runOnOperation() override; 1050 }; 1051 1052 void ConvertAsyncToLLVMPass::runOnOperation() { 1053 ModuleOp module = getOperation(); 1054 SymbolTable symbolTable(module); 1055 1056 MLIRContext *ctx = &getContext(); 1057 1058 // Outline all `async.execute` body regions into async functions (coroutines). 
1059 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 1060 1061 // We use conversion to LLVM type to ensure that all `async.value` operands 1062 // and results can be lowered to LLVM load and store operations. 1063 LLVMTypeConverter llvmConverter(ctx); 1064 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 1065 1066 // Returns true if the `async.value` payload is convertible to LLVM. 1067 auto isConvertibleToLlvm = [&](Type type) -> bool { 1068 auto valueType = type.cast<ValueType>().getValueType(); 1069 return static_cast<bool>(llvmConverter.convertType(valueType)); 1070 }; 1071 1072 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 1073 // All operands and results must be convertible to LLVM. 1074 if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) { 1075 execute.emitOpError("operands payload must be convertible to LLVM type"); 1076 return WalkResult::interrupt(); 1077 } 1078 if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) { 1079 execute.emitOpError("results payload must be convertible to LLVM type"); 1080 return WalkResult::interrupt(); 1081 } 1082 1083 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 1084 1085 return WalkResult::advance(); 1086 }); 1087 1088 // Failed to outline all async execute operations. 1089 if (outlineResult.wasInterrupted()) { 1090 signalPassFailure(); 1091 return; 1092 } 1093 1094 LLVM_DEBUG({ 1095 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 1096 << " async functions\n"; 1097 }); 1098 1099 // Add declarations for all functions required by the coroutines lowering. 1100 addResumeFunction(module); 1101 addAsyncRuntimeApiDeclarations(module); 1102 addCoroutineIntrinsicsDeclarations(module); 1103 addCRuntimeDeclarations(module); 1104 1105 // Convert async dialect types and operations to LLVM dialect. 
1106 AsyncRuntimeTypeConverter converter; 1107 OwningRewritePatternList patterns; 1108 1109 // Convert async types in function signatures and function calls. 1110 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 1111 populateCallOpTypeConversionPattern(patterns, ctx, converter); 1112 1113 // Convert return operations inside async.execute regions. 1114 patterns.insert<ReturnOpOpConversion>(converter, ctx); 1115 1116 // Lower async operations to async runtime API calls. 1117 patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx); 1118 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx); 1119 1120 // Use LLVM type converter to automatically convert between the async value 1121 // payload type and LLVM type when loading/storing from/to the async 1122 // value storage which is an opaque i8* pointer using LLVM load/store ops. 1123 patterns 1124 .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>( 1125 llvmConverter, ctx, outlinedFunctions); 1126 patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions); 1127 1128 ConversionTarget target(*ctx); 1129 target.addLegalOp<ConstantOp>(); 1130 target.addLegalDialect<LLVM::LLVMDialect>(); 1131 1132 // All operations from Async dialect must be lowered to the runtime API calls. 1133 target.addIllegalDialect<AsyncDialect>(); 1134 1135 // Add dynamic legality constraints to apply conversions defined above. 
1136 target.addDynamicallyLegalOp<FuncOp>( 1137 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1138 target.addDynamicallyLegalOp<ReturnOp>( 1139 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1140 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1141 return converter.isSignatureLegal(op.getCalleeType()); 1142 }); 1143 1144 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1145 signalPassFailure(); 1146 } 1147 } // namespace 1148 1149 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1150 return std::make_unique<ConvertAsyncToLLVMPass>(); 1151 } 1152