1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/BlockAndValueMapping.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "mlir/Transforms/RegionUtils.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/Support/FormatVariadic.h" 25 26 #define DEBUG_TYPE "convert-async-to-llvm" 27 28 using namespace mlir; 29 using namespace mlir::async; 30 31 // Prefix for functions outlined from `async.execute` op regions. 32 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 33 34 //===----------------------------------------------------------------------===// 35 // Async Runtime C API declaration. 36 //===----------------------------------------------------------------------===// 37 38 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 39 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 40 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 41 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 42 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 43 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 44 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 45 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 46 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 47 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 48 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 49 static constexpr const char *kGetValueStorage = 50 "mlirAsyncRuntimeGetValueStorage"; 51 static constexpr const char *kAddTokenToGroup = 52 "mlirAsyncRuntimeAddTokenToGroup"; 53 static constexpr const char *kAwaitTokenAndExecute = 54 "mlirAsyncRuntimeAwaitTokenAndExecute"; 55 static constexpr const char *kAwaitValueAndExecute = 56 "mlirAsyncRuntimeAwaitValueAndExecute"; 57 static constexpr const char *kAwaitAllAndExecute = 58 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 59 60 namespace { 61 /// Async Runtime API function types. 62 /// 63 /// Because we can't create API function signature for type parametrized 64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 65 /// lowering all async data types become opaque pointers at runtime. 66 struct AsyncAPI { 67 // All async types are lowered to opaque i8* LLVM pointers at runtime. 68 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 69 return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); 70 } 71 72 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 73 auto ref = opaquePointerType(ctx); 74 auto count = IntegerType::get(ctx, 32); 75 return FunctionType::get(ctx, {ref, count}, {}); 76 } 77 78 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 79 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 80 } 81 82 static FunctionType createValueFunctionType(MLIRContext *ctx) { 83 auto i32 = IntegerType::get(ctx, 32); 84 auto value = opaquePointerType(ctx); 85 return FunctionType::get(ctx, {i32}, {value}); 86 } 87 88 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 89 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 90 } 91 92 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 93 auto value = opaquePointerType(ctx); 94 auto storage = opaquePointerType(ctx); 95 return FunctionType::get(ctx, {value}, {storage}); 96 } 97 98 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 99 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 100 } 101 102 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 103 auto value = opaquePointerType(ctx); 104 return FunctionType::get(ctx, {value}, {}); 105 } 106 107 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 108 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 109 } 110 111 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 112 auto value = opaquePointerType(ctx); 113 return FunctionType::get(ctx, {value}, {}); 114 } 115 116 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 117 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 118 } 119 120 static FunctionType executeFunctionType(MLIRContext *ctx) { 121 auto hdl = opaquePointerType(ctx); 122 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 123 return FunctionType::get(ctx, {hdl, resume}, {}); 124 } 125 126 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 127 auto i64 = IntegerType::get(ctx, 64); 128 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 129 {i64}); 130 } 131 132 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 133 auto hdl = opaquePointerType(ctx); 134 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 135 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 136 } 137 138 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 139 auto value = opaquePointerType(ctx); 140 auto hdl = opaquePointerType(ctx); 141 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 142 return FunctionType::get(ctx, {value, hdl, resume}, {}); 143 } 144 145 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 146 auto hdl = opaquePointerType(ctx); 147 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 148 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 149 } 150 151 // Auxiliary coroutine resume intrinsic wrapper. 152 static Type resumeFunctionType(MLIRContext *ctx) { 153 auto voidTy = LLVM::LLVMVoidType::get(ctx); 154 auto i8Ptr = opaquePointerType(ctx); 155 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 156 } 157 }; 158 } // namespace 159 160 /// Adds Async Runtime C API declarations to the module. 161 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 162 auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), 163 module.getBody()); 164 165 auto addFuncDecl = [&](StringRef name, FunctionType type) { 166 if (module.lookupSymbol(name)) 167 return; 168 builder.create<FuncOp>(name, type).setPrivate(); 169 }; 170 171 MLIRContext *ctx = module.getContext(); 172 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 173 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 174 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 175 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 176 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 177 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 178 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 179 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 180 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 181 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 182 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 183 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 184 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 185 addFuncDecl(kAwaitTokenAndExecute, 186 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 187 addFuncDecl(kAwaitValueAndExecute, 188 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 189 addFuncDecl(kAwaitAllAndExecute, 190 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 191 } 192 193 //===----------------------------------------------------------------------===// 194 // LLVM coroutines intrinsics declarations. 195 //===----------------------------------------------------------------------===// 196 197 static constexpr const char *kCoroId = "llvm.coro.id"; 198 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 199 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 200 static constexpr const char *kCoroSave = "llvm.coro.save"; 201 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 202 static constexpr const char *kCoroEnd = "llvm.coro.end"; 203 static constexpr const char *kCoroFree = "llvm.coro.free"; 204 static constexpr const char *kCoroResume = "llvm.coro.resume"; 205 206 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 207 StringRef name, Type ret, ArrayRef<Type> params) { 208 if (module.lookupSymbol(name)) 209 return; 210 Type type = LLVM::LLVMFunctionType::get(ret, params); 211 builder.create<LLVM::LLVMFuncOp>(name, type); 212 } 213 214 /// Adds coroutine intrinsics declarations to the module. 215 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 216 using namespace mlir::LLVM; 217 218 MLIRContext *ctx = module.getContext(); 219 ImplicitLocOpBuilder builder(module.getLoc(), 220 module.getBody()->getTerminator()); 221 222 auto token = LLVMTokenType::get(ctx); 223 auto voidTy = LLVMVoidType::get(ctx); 224 225 auto i8 = LLVMIntegerType::get(ctx, 8); 226 auto i1 = LLVMIntegerType::get(ctx, 1); 227 auto i32 = LLVMIntegerType::get(ctx, 32); 228 auto i64 = LLVMIntegerType::get(ctx, 64); 229 auto i8Ptr = LLVMPointerType::get(i8); 230 231 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); 232 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); 233 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); 234 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); 235 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); 236 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); 237 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); 238 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // Add malloc/free declarations to the module. 243 //===----------------------------------------------------------------------===// 244 245 static constexpr const char *kMalloc = "malloc"; 246 static constexpr const char *kFree = "free"; 247 248 /// Adds malloc/free declarations to the module. 249 static void addCRuntimeDeclarations(ModuleOp module) { 250 using namespace mlir::LLVM; 251 252 MLIRContext *ctx = module.getContext(); 253 ImplicitLocOpBuilder builder(module.getLoc(), 254 module.getBody()->getTerminator()); 255 256 auto voidTy = LLVMVoidType::get(ctx); 257 auto i64 = LLVMIntegerType::get(ctx, 64); 258 auto i8Ptr = LLVMPointerType::get(LLVMIntegerType::get(ctx, 8)); 259 260 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 261 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 262 } 263 264 //===----------------------------------------------------------------------===// 265 // Coroutine resume function wrapper. 266 //===----------------------------------------------------------------------===// 267 268 static constexpr const char *kResume = "__resume"; 269 270 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 271 /// intrinsics. We need this function to be able to pass it to the async 272 /// runtime execute API. 273 static void addResumeFunction(ModuleOp module) { 274 MLIRContext *ctx = module.getContext(); 275 276 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 277 Location loc = module.getLoc(); 278 279 if (module.lookupSymbol(kResume)) 280 return; 281 282 auto voidTy = LLVM::LLVMVoidType::get(ctx); 283 auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); 284 285 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 286 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 287 resumeOp.setPrivate(); 288 289 auto *block = resumeOp.addEntryBlock(); 290 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 291 292 blockBuilder.create<LLVM::CallOp>(TypeRange(), 293 blockBuilder.getSymbolRefAttr(kCoroResume), 294 resumeOp.getArgument(0)); 295 296 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // async.execute op outlining to the coroutine functions. 301 //===----------------------------------------------------------------------===// 302 303 /// Function targeted for coroutine transformation has two additional blocks at 304 /// the end: coroutine cleanup and coroutine suspension. 305 /// 306 /// async.await op lowering additionaly creates a resume block for each 307 /// operation to enable non-blocking waiting via coroutine suspension. 308 namespace { 309 struct CoroMachinery { 310 // Async execute region returns a completion token, and an async value for 311 // each yielded value. 312 // 313 // %token, %result = async.execute -> !async.value<T> { 314 // %0 = constant ... : T 315 // async.yield %0 : T 316 // } 317 Value asyncToken; // token representing completion of the async region 318 llvm::SmallVector<Value, 4> returnValues; // returned async values 319 320 Value coroHandle; 321 Block *cleanup; 322 Block *suspend; 323 }; 324 } // namespace 325 326 /// Builds an coroutine template compatible with LLVM coroutines lowering. 327 /// 328 /// - `entry` block sets up the coroutine. 329 /// - `cleanup` block cleans up the coroutine state. 330 /// - `suspend block after the @llvm.coro.end() defines what value will be 331 /// returned to the initial caller of a coroutine. Everything before the 332 /// @llvm.coro.end() will be executed at every suspension point. 333 /// 334 /// Coroutine structure (only the important bits): 335 /// 336 /// func @async_execute_fn(<function-arguments>) 337 /// -> (!async.token, !async.value<T>) 338 /// { 339 /// ^entryBlock(<function-arguments>): 340 /// %token = <async token> : !async.token // create async runtime token 341 /// %value = <async value> : !async.value<T> // create async value 342 /// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle 343 /// br ^cleanup 344 /// 345 /// ^cleanup: 346 /// llvm.call @llvm.coro.free(...) // delete coroutine state 347 /// br ^suspend 348 /// 349 /// ^suspend: 350 /// llvm.call @llvm.coro.end(...) // marks the end of a coroutine 351 /// return %token, %value : !async.token, !async.value<T> 352 /// } 353 /// 354 /// The actual code for the async.execute operation body region will be inserted 355 /// before the entry block terminator. 356 /// 357 /// 358 static CoroMachinery setupCoroMachinery(FuncOp func) { 359 assert(func.getBody().empty() && "Function must have empty body"); 360 361 MLIRContext *ctx = func.getContext(); 362 363 auto token = LLVM::LLVMTokenType::get(ctx); 364 auto i1 = LLVM::LLVMIntegerType::get(ctx, 1); 365 auto i32 = LLVM::LLVMIntegerType::get(ctx, 32); 366 auto i64 = LLVM::LLVMIntegerType::get(ctx, 64); 367 auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); 368 369 Block *entryBlock = func.addEntryBlock(); 370 Location loc = func.getBody().getLoc(); 371 372 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock); 373 374 // ------------------------------------------------------------------------ // 375 // Allocate async tokens/values that we will return from a ramp function. 376 // ------------------------------------------------------------------------ // 377 auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx)); 378 379 // Async value operands and results must be convertible to LLVM types. This is 380 // verified before the function outlining. 381 LLVMTypeConverter converter(ctx); 382 383 // Returns the size requirements for the async value storage. 384 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 385 auto sizeOf = [&](ValueType valueType) -> Value { 386 auto storedType = converter.convertType(valueType.getValueType()); 387 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 388 389 // %Size = getelementptr %T* null, int 1 390 // %SizeI = ptrtoint %T* %Size to i32 391 auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType); 392 auto one = builder.create<LLVM::ConstantOp>(loc, i32, 393 builder.getI32IntegerAttr(1)); 394 auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 395 one.getResult()); 396 auto size = builder.create<LLVM::PtrToIntOp>(loc, i32, gep); 397 398 // Cast to std type because runtime API defined using std types. 399 return builder.create<LLVM::DialectCastOp>(loc, builder.getI32Type(), 400 size.getResult()); 401 }; 402 403 // We use the `async.value` type as a return type although it does not match 404 // the `kCreateValue` function signature, because it will be later lowered to 405 // the runtime type (opaque i8* pointer). 406 llvm::SmallVector<CallOp, 4> createValues; 407 for (auto resultType : func.getCallableResults().drop_front(1)) 408 createValues.emplace_back(builder.create<CallOp>( 409 loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>()))); 410 411 auto createdValues = llvm::map_range( 412 createValues, [](CallOp call) { return call.getResult(0); }); 413 llvm::SmallVector<Value, 4> returnValues(createdValues.begin(), 414 createdValues.end()); 415 416 // ------------------------------------------------------------------------ // 417 // Initialize coroutine: allocate frame, get coroutine handle. 418 // ------------------------------------------------------------------------ // 419 420 // Constants for initializing coroutine frame. 421 auto constZero = 422 builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0)); 423 auto constFalse = 424 builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false)); 425 auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr); 426 427 // Get coroutine id: @llvm.coro.id 428 auto coroId = builder.create<LLVM::CallOp>( 429 token, builder.getSymbolRefAttr(kCoroId), 430 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 431 432 // Get coroutine frame size: @llvm.coro.size.i64 433 auto coroSize = builder.create<LLVM::CallOp>( 434 i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 435 436 // Allocate memory for coroutine frame. 437 auto coroAlloc = 438 builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc), 439 ValueRange(coroSize.getResult(0))); 440 441 // Begin a coroutine: @llvm.coro.begin 442 auto coroHdl = builder.create<LLVM::CallOp>( 443 i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 444 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 445 446 Block *cleanupBlock = func.addBlock(); 447 Block *suspendBlock = func.addBlock(); 448 449 // ------------------------------------------------------------------------ // 450 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 451 // ------------------------------------------------------------------------ // 452 builder.setInsertionPointToStart(cleanupBlock); 453 454 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 455 auto coroMem = builder.create<LLVM::CallOp>( 456 i8Ptr, builder.getSymbolRefAttr(kCoroFree), 457 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 458 459 // Free the memory. 460 builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree), 461 ValueRange(coroMem.getResult(0))); 462 // Branch into the suspend block. 463 builder.create<BranchOp>(suspendBlock); 464 465 // ------------------------------------------------------------------------ // 466 // Coroutine suspend block: mark the end of a coroutine and return allocated 467 // async token. 468 // ------------------------------------------------------------------------ // 469 builder.setInsertionPointToStart(suspendBlock); 470 471 // Mark the end of a coroutine: @llvm.coro.end. 472 builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd), 473 ValueRange({coroHdl.getResult(0), constFalse})); 474 475 // Return created `async.token` and `async.values` from the suspend block. 476 // This will be the return value of a coroutine ramp function. 477 SmallVector<Value, 4> ret{createToken.getResult(0)}; 478 ret.insert(ret.end(), returnValues.begin(), returnValues.end()); 479 builder.create<ReturnOp>(loc, ret); 480 481 // Branch from the entry block to the cleanup block to create a valid CFG. 482 builder.setInsertionPointToEnd(entryBlock); 483 484 builder.create<BranchOp>(cleanupBlock); 485 486 // `async.await` op lowering will create resume blocks for async 487 // continuations, and will conditionally branch to cleanup or suspend blocks. 488 489 CoroMachinery machinery; 490 machinery.asyncToken = createToken.getResult(0); 491 machinery.returnValues = returnValues; 492 machinery.coroHandle = coroHdl.getResult(0); 493 machinery.cleanup = cleanupBlock; 494 machinery.suspend = suspendBlock; 495 return machinery; 496 } 497 498 /// Add a LLVM coroutine suspension point to the end of suspended block, to 499 /// resume execution in resume block. The caller is responsible for creating the 500 /// two suspended/resume blocks with the desired ops contained in each block. 501 /// This function merely provides the required control flow logic. 502 /// 503 /// `coroState` must be a value returned from the call to @llvm.coro.save(...) 504 /// intrinsic (saved coroutine state). 505 /// 506 /// Before: 507 /// 508 /// ^bb0: 509 /// "opBefore"(...) 510 /// "op"(...) 511 /// ^cleanup: ... 512 /// ^suspend: ... 513 /// ^resume: 514 /// "op"(...) 515 /// 516 /// After: 517 /// 518 /// ^bb0: 519 /// "opBefore"(...) 520 /// %suspend = llmv.call @llvm.coro.suspend(...) 521 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 522 /// ^resume: 523 /// "op"(...) 524 /// ^cleanup: ... 525 /// ^suspend: ... 526 /// 527 static void addSuspensionPoint(CoroMachinery coro, Value coroState, 528 Operation *op, Block *suspended, Block *resume, 529 OpBuilder &builder) { 530 Location loc = op->getLoc(); 531 MLIRContext *ctx = op->getContext(); 532 auto i1 = LLVM::LLVMIntegerType::get(ctx, 1); 533 auto i8 = LLVM::LLVMIntegerType::get(ctx, 8); 534 535 // Add a coroutine suspension in place of original `op` in the split block. 536 OpBuilder::InsertionGuard guard(builder); 537 builder.setInsertionPointToEnd(suspended); 538 539 auto constFalse = 540 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 541 542 // Suspend a coroutine: @llvm.coro.suspend 543 auto coroSuspend = builder.create<LLVM::CallOp>( 544 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 545 ValueRange({coroState, constFalse})); 546 547 // After a suspension point decide if we should branch into resume, cleanup 548 // or suspend block of the coroutine (see @llvm.coro.suspend return code 549 // documentation). 550 auto constZero = 551 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 552 auto constNegOne = 553 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 554 555 Block *resumeOrCleanup = builder.createBlock(resume); 556 557 // Suspend the coroutine ...? 558 builder.setInsertionPointToEnd(suspended); 559 auto isNegOne = builder.create<LLVM::ICmpOp>( 560 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 561 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 562 /*falseDest=*/resumeOrCleanup); 563 564 // ... or resume or cleanup the coroutine? 565 builder.setInsertionPointToStart(resumeOrCleanup); 566 auto isZero = builder.create<LLVM::ICmpOp>( 567 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 568 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 569 /*falseDest=*/coro.cleanup); 570 } 571 572 /// Outline the body region attached to the `async.execute` op into a standalone 573 /// function. 574 /// 575 /// Note that this is not reversible transformation. 576 static std::pair<FuncOp, CoroMachinery> 577 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 578 ModuleOp module = execute->getParentOfType<ModuleOp>(); 579 580 MLIRContext *ctx = module.getContext(); 581 Location loc = execute.getLoc(); 582 583 // Collect all outlined function inputs. 584 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 585 execute.dependencies().end()); 586 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 587 getUsedValuesDefinedAbove(execute.body(), functionInputs); 588 589 // Collect types for the outlined function inputs and outputs. 590 auto typesRange = llvm::map_range( 591 functionInputs, [](Value value) { return value.getType(); }); 592 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 593 auto outputTypes = execute.getResultTypes(); 594 595 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 596 auto funcAttrs = ArrayRef<NamedAttribute>(); 597 598 // TODO: Derive outlined function name from the parent FuncOp (support 599 // multiple nested async.execute operations). 600 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 601 symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); 602 603 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 604 605 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 606 // blocks, adding llvm.coro instrinsics and setting up control flow. 607 CoroMachinery coro = setupCoroMachinery(func); 608 609 // Suspend async function at the end of an entry block, and resume it using 610 // Async execute API (execution will be resumed in a thread managed by the 611 // async runtime). 612 Block *entryBlock = &func.getBlocks().front(); 613 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 614 615 // A pointer to coroutine resume intrinsic wrapper. 616 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 617 auto resumePtr = builder.create<LLVM::AddressOfOp>( 618 LLVM::LLVMPointerType::get(resumeFnTy), kResume); 619 620 // Save the coroutine state: @llvm.coro.save 621 auto coroSave = builder.create<LLVM::CallOp>( 622 LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 623 ValueRange({coro.coroHandle})); 624 625 // Call async runtime API to execute a coroutine in the managed thread. 626 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 627 builder.create<CallOp>(TypeRange(), kExecute, executeArgs); 628 629 // Split the entry block before the terminator. 630 auto *terminatorOp = entryBlock->getTerminator(); 631 Block *suspended = terminatorOp->getBlock(); 632 Block *resume = suspended->splitBlock(terminatorOp); 633 addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, 634 resume, builder); 635 636 size_t numDependencies = execute.dependencies().size(); 637 size_t numOperands = execute.operands().size(); 638 639 // Await on all dependencies before starting to execute the body region. 640 builder.setInsertionPointToStart(resume); 641 for (size_t i = 0; i < numDependencies; ++i) 642 builder.create<AwaitOp>(func.getArgument(i)); 643 644 // Await on all async value operands and unwrap the payload. 645 SmallVector<Value, 4> unwrappedOperands(numOperands); 646 for (size_t i = 0; i < numOperands; ++i) { 647 Value operand = func.getArgument(numDependencies + i); 648 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 649 } 650 651 // Map from function inputs defined above the execute op to the function 652 // arguments. 653 BlockAndValueMapping valueMapping; 654 valueMapping.map(functionInputs, func.getArguments()); 655 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 656 657 // Clone all operations from the execute operation body into the outlined 658 // function body. 659 for (Operation &op : execute.body().getOps()) 660 builder.clone(op, valueMapping); 661 662 // Replace the original `async.execute` with a call to outlined function. 663 ImplicitLocOpBuilder callBuilder(loc, execute); 664 auto callOutlinedFunc = callBuilder.create<CallOp>( 665 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 666 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 667 execute.erase(); 668 669 return {func, coro}; 670 } 671 672 //===----------------------------------------------------------------------===// 673 // Convert Async dialect types to LLVM types. 674 //===----------------------------------------------------------------------===// 675 676 namespace { 677 678 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 679 /// their runtime type (opaque pointers) and does not convert any other types. 680 class AsyncRuntimeTypeConverter : public TypeConverter { 681 public: 682 AsyncRuntimeTypeConverter() { 683 addConversion([](Type type) { return type; }); 684 addConversion(convertAsyncTypes); 685 } 686 687 static Optional<Type> convertAsyncTypes(Type type) { 688 if (type.isa<TokenType, GroupType, ValueType>()) 689 return AsyncAPI::opaquePointerType(type.getContext()); 690 return llvm::None; 691 } 692 }; 693 } // namespace 694 695 //===----------------------------------------------------------------------===// 696 // Convert return operations that return async values from async regions. 697 //===----------------------------------------------------------------------===// 698 699 namespace { 700 class ReturnOpOpConversion : public ConversionPattern { 701 public: 702 explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx) 703 : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {} 704 705 LogicalResult 706 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 707 ConversionPatternRewriter &rewriter) const override { 708 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 709 return success(); 710 } 711 }; 712 } // namespace 713 714 //===----------------------------------------------------------------------===// 715 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` 716 // to the corresponding API calls). 717 //===----------------------------------------------------------------------===// 718 719 namespace { 720 721 template <typename RefCountingOp> 722 class RefCountingOpLowering : public ConversionPattern { 723 public: 724 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 725 StringRef apiFunctionName) 726 : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx), 727 apiFunctionName(apiFunctionName) {} 728 729 LogicalResult 730 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 731 ConversionPatternRewriter &rewriter) const override { 732 RefCountingOp refCountingOp = cast<RefCountingOp>(op); 733 734 auto count = rewriter.create<ConstantOp>( 735 op->getLoc(), rewriter.getI32Type(), 736 rewriter.getI32IntegerAttr(refCountingOp.count())); 737 738 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 739 ValueRange({operands[0], count})); 740 741 return success(); 742 } 743 744 private: 745 StringRef apiFunctionName; 746 }; 747 748 /// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. 749 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> { 750 public: 751 explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 752 : RefCountingOpLowering(converter, ctx, kAddRef) {} 753 }; 754 755 /// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 756 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> { 757 public: 758 explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 759 : RefCountingOpLowering(converter, ctx, kDropRef) {} 760 }; 761 762 } // namespace 763 764 //===----------------------------------------------------------------------===// 765 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 766 //===----------------------------------------------------------------------===// 767 768 namespace { 769 class CreateGroupOpLowering : public ConversionPattern { 770 public: 771 explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) 772 : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter, 773 ctx) {} 774 775 LogicalResult 776 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 777 ConversionPatternRewriter &rewriter) const override { 778 auto retTy = GroupType::get(op->getContext()); 779 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy); 780 return success(); 781 } 782 }; 783 } // namespace 784 785 //===----------------------------------------------------------------------===// 786 // async.add_to_group op lowering to runtime function call. 787 //===----------------------------------------------------------------------===// 788 789 namespace { 790 class AddToGroupOpLowering : public ConversionPattern { 791 public: 792 explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) 793 : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) { 794 } 795 796 LogicalResult 797 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 798 ConversionPatternRewriter &rewriter) const override { 799 // Currently we can only add tokens to the group. 800 auto addToGroup = cast<AddToGroupOp>(op); 801 if (!addToGroup.operand().getType().isa<TokenType>()) 802 return failure(); 803 804 auto i64 = IntegerType::get(op->getContext(), 64); 805 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); 806 return success(); 807 } 808 }; 809 } // namespace 810 811 //===----------------------------------------------------------------------===// 812 // async.await and async.await_all op lowerings to the corresponding async 813 // runtime function calls. 814 //===----------------------------------------------------------------------===// 815 816 namespace { 817 818 template <typename AwaitType, typename AwaitableType> 819 class AwaitOpLoweringBase : public ConversionPattern { 820 protected: 821 explicit AwaitOpLoweringBase( 822 TypeConverter &converter, MLIRContext *ctx, 823 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions, 824 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) 825 : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx), 826 outlinedFunctions(outlinedFunctions), 827 blockingAwaitFuncName(blockingAwaitFuncName), 828 coroAwaitFuncName(coroAwaitFuncName) {} 829 830 public: 831 LogicalResult 832 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 833 ConversionPatternRewriter &rewriter) const override { 834 // We can only await on one the `AwaitableType` (for `await` it can be 835 // a `token` or a `value`, for `await_all` it must be a `group`). 836 auto await = cast<AwaitType>(op); 837 if (!await.operand().getType().template isa<AwaitableType>()) 838 return failure(); 839 840 // Check if await operation is inside the outlined coroutine function. 841 auto func = await->template getParentOfType<FuncOp>(); 842 auto outlined = outlinedFunctions.find(func); 843 const bool isInCoroutine = outlined != outlinedFunctions.end(); 844 845 Location loc = op->getLoc(); 846 847 // Inside regular function we convert await operation to the blocking 848 // async API await function call. 849 if (!isInCoroutine) 850 rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName, 851 ValueRange(operands[0])); 852 853 // Inside the coroutine we convert await operation into coroutine suspension 854 // point, and resume execution asynchronously. 855 if (isInCoroutine) { 856 const CoroMachinery &coro = outlined->getSecond(); 857 858 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 859 MLIRContext *ctx = op->getContext(); 860 861 // A pointer to coroutine resume intrinsic wrapper. 862 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 863 auto resumePtr = builder.create<LLVM::AddressOfOp>( 864 LLVM::LLVMPointerType::get(resumeFnTy), kResume); 865 866 // Save the coroutine state: @llvm.coro.save 867 auto coroSave = builder.create<LLVM::CallOp>( 868 LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 869 ValueRange(coro.coroHandle)); 870 871 // Call async runtime API to resume a coroutine in the managed thread when 872 // the async await argument becomes ready. 873 SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle, 874 resumePtr.res()}; 875 builder.create<CallOp>(TypeRange(), coroAwaitFuncName, 876 awaitAndExecuteArgs); 877 878 Block *suspended = op->getBlock(); 879 880 // Split the entry block before the await operation. 881 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 882 addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume, 883 builder); 884 885 // Make sure that replacement value will be constructed in resume block. 886 rewriter.setInsertionPointToStart(resume); 887 } 888 889 // Replace or erase the await operation with the new value. 890 if (Value replaceWith = getReplacementValue(op, operands[0], rewriter)) 891 rewriter.replaceOp(op, replaceWith); 892 else 893 rewriter.eraseOp(op); 894 895 return success(); 896 } 897 898 virtual Value getReplacementValue(Operation *op, Value operand, 899 ConversionPatternRewriter &rewriter) const { 900 return Value(); 901 } 902 903 private: 904 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 905 StringRef blockingAwaitFuncName; 906 StringRef coroAwaitFuncName; 907 }; 908 909 /// Lowering for `async.await` with a token operand. 910 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 911 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 912 913 public: 914 explicit AwaitTokenOpLowering( 915 TypeConverter &converter, MLIRContext *ctx, 916 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 917 : Base(converter, ctx, outlinedFunctions, kAwaitToken, 918 kAwaitTokenAndExecute) {} 919 }; 920 921 /// Lowering for `async.await` with a value operand. 922 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 923 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 924 925 public: 926 explicit AwaitValueOpLowering( 927 TypeConverter &converter, MLIRContext *ctx, 928 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 929 : Base(converter, ctx, outlinedFunctions, kAwaitValue, 930 kAwaitValueAndExecute) {} 931 932 Value 933 getReplacementValue(Operation *op, Value operand, 934 ConversionPatternRewriter &rewriter) const override { 935 Location loc = op->getLoc(); 936 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 937 938 // Get the underlying value type from the `async.value`. 939 auto await = cast<AwaitOp>(op); 940 auto valueType = await.operand().getType().cast<ValueType>().getValueType(); 941 942 // Get a pointer to an async value storage from the runtime. 943 auto storage = rewriter.create<CallOp>(loc, kGetValueStorage, 944 TypeRange(i8Ptr), operand); 945 946 // Cast from i8* to the pointer pointer to LLVM type. 947 auto llvmValueType = getTypeConverter()->convertType(valueType); 948 auto castedStorage = rewriter.create<LLVM::BitcastOp>( 949 loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0)); 950 951 // Load from the async value storage. 952 auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult()); 953 954 // Cast from LLVM type to the expected value type. This cast will become 955 // no-op after lowering to LLVM. 956 return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded); 957 } 958 }; 959 960 /// Lowering for `async.await_all` operation. 961 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 962 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 963 964 public: 965 explicit AwaitAllOpLowering( 966 TypeConverter &converter, MLIRContext *ctx, 967 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 968 : Base(converter, ctx, outlinedFunctions, kAwaitGroup, 969 kAwaitAllAndExecute) {} 970 }; 971 972 } // namespace 973 974 //===----------------------------------------------------------------------===// 975 // async.yield op lowerings to the corresponding async runtime function calls. 976 //===----------------------------------------------------------------------===// 977 978 class YieldOpLowering : public ConversionPattern { 979 public: 980 explicit YieldOpLowering( 981 TypeConverter &converter, MLIRContext *ctx, 982 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 983 : ConversionPattern(async::YieldOp::getOperationName(), 1, converter, 984 ctx), 985 outlinedFunctions(outlinedFunctions) {} 986 987 LogicalResult 988 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 989 ConversionPatternRewriter &rewriter) const override { 990 // Check if yield operation is inside the outlined coroutine function. 991 auto func = op->template getParentOfType<FuncOp>(); 992 auto outlined = outlinedFunctions.find(func); 993 if (outlined == outlinedFunctions.end()) 994 return op->emitOpError( 995 "async.yield is not inside the outlined coroutine function"); 996 997 Location loc = op->getLoc(); 998 const CoroMachinery &coro = outlined->getSecond(); 999 1000 // Store yielded values into the async values storage and emplace them. 1001 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 1002 1003 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 1004 // Store `yieldValue` into the `asyncValue` storage. 1005 Value yieldValue = std::get<0>(tuple); 1006 Value asyncValue = std::get<1>(tuple); 1007 1008 // Get an opaque i8* pointer to an async value storage from the runtime. 1009 auto storage = rewriter.create<CallOp>(loc, kGetValueStorage, 1010 TypeRange(i8Ptr), asyncValue); 1011 1012 // Cast storage pointer to the yielded value type. 1013 auto castedStorage = rewriter.create<LLVM::BitcastOp>( 1014 loc, LLVM::LLVMPointerType::get(yieldValue.getType()), 1015 storage.getResult(0)); 1016 1017 // Store the yielded value into the async value storage. 1018 rewriter.create<LLVM::StoreOp>(loc, yieldValue, 1019 castedStorage.getResult()); 1020 1021 // Emplace the `async.value` to mark it ready. 1022 rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue); 1023 } 1024 1025 // Emplace the completion token to mark it ready. 1026 rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken); 1027 1028 // Original operation was replaced by the function call(s). 1029 rewriter.eraseOp(op); 1030 1031 return success(); 1032 } 1033 1034 private: 1035 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 1036 }; 1037 1038 //===----------------------------------------------------------------------===// 1039 1040 namespace { 1041 struct ConvertAsyncToLLVMPass 1042 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 1043 void runOnOperation() override; 1044 }; 1045 1046 void ConvertAsyncToLLVMPass::runOnOperation() { 1047 ModuleOp module = getOperation(); 1048 SymbolTable symbolTable(module); 1049 1050 MLIRContext *ctx = &getContext(); 1051 1052 // Outline all `async.execute` body regions into async functions (coroutines). 1053 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 1054 1055 // We use conversion to LLVM type to ensure that all `async.value` operands 1056 // and results can be lowered to LLVM load and store operations. 1057 LLVMTypeConverter llvmConverter(ctx); 1058 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 1059 1060 // Returns true if the `async.value` payload is convertible to LLVM. 1061 auto isConvertibleToLlvm = [&](Type type) -> bool { 1062 auto valueType = type.cast<ValueType>().getValueType(); 1063 return static_cast<bool>(llvmConverter.convertType(valueType)); 1064 }; 1065 1066 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 1067 // All operands and results must be convertible to LLVM. 1068 if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) { 1069 execute.emitOpError("operands payload must be convertible to LLVM type"); 1070 return WalkResult::interrupt(); 1071 } 1072 if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) { 1073 execute.emitOpError("results payload must be convertible to LLVM type"); 1074 return WalkResult::interrupt(); 1075 } 1076 1077 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 1078 1079 return WalkResult::advance(); 1080 }); 1081 1082 // Failed to outline all async execute operations. 1083 if (outlineResult.wasInterrupted()) { 1084 signalPassFailure(); 1085 return; 1086 } 1087 1088 LLVM_DEBUG({ 1089 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 1090 << " async functions\n"; 1091 }); 1092 1093 // Add declarations for all functions required by the coroutines lowering. 1094 addResumeFunction(module); 1095 addAsyncRuntimeApiDeclarations(module); 1096 addCoroutineIntrinsicsDeclarations(module); 1097 addCRuntimeDeclarations(module); 1098 1099 // Convert async dialect types and operations to LLVM dialect. 1100 AsyncRuntimeTypeConverter converter; 1101 OwningRewritePatternList patterns; 1102 1103 // Convert async types in function signatures and function calls. 1104 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 1105 populateCallOpTypeConversionPattern(patterns, ctx, converter); 1106 1107 // Convert return operations inside async.execute regions. 1108 patterns.insert<ReturnOpOpConversion>(converter, ctx); 1109 1110 // Lower async operations to async runtime API calls. 1111 patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx); 1112 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx); 1113 1114 // Use LLVM type converter to automatically convert between the async value 1115 // payload type and LLVM type when loading/storing from/to the async 1116 // value storage which is an opaque i8* pointer using LLVM load/store ops. 1117 patterns 1118 .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>( 1119 llvmConverter, ctx, outlinedFunctions); 1120 patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions); 1121 1122 ConversionTarget target(*ctx); 1123 target.addLegalOp<ConstantOp>(); 1124 target.addLegalDialect<LLVM::LLVMDialect>(); 1125 1126 // All operations from Async dialect must be lowered to the runtime API calls. 1127 target.addIllegalDialect<AsyncDialect>(); 1128 1129 // Add dynamic legality constraints to apply conversions defined above. 1130 target.addDynamicallyLegalOp<FuncOp>( 1131 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1132 target.addDynamicallyLegalOp<ReturnOp>( 1133 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1134 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1135 return converter.isSignatureLegal(op.getCalleeType()); 1136 }); 1137 1138 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1139 signalPassFailure(); 1140 } 1141 } // namespace 1142 1143 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1144 return std::make_unique<ConvertAsyncToLLVMPass>(); 1145 } 1146