1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Dialect/Async/IR/Async.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" 17 #include "mlir/IR/BlockAndValueMapping.h" 18 #include "mlir/IR/ImplicitLocOpBuilder.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "mlir/Transforms/RegionUtils.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/Support/FormatVariadic.h" 25 26 #define DEBUG_TYPE "convert-async-to-llvm" 27 28 using namespace mlir; 29 using namespace mlir::async; 30 31 // Prefix for functions outlined from `async.execute` op regions. 32 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 33 34 //===----------------------------------------------------------------------===// 35 // Async Runtime C API declaration. 36 //===----------------------------------------------------------------------===// 37 38 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 39 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 40 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 41 static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 42 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 43 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 44 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 45 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 46 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 47 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 48 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 49 static constexpr const char *kGetValueStorage = 50 "mlirAsyncRuntimeGetValueStorage"; 51 static constexpr const char *kAddTokenToGroup = 52 "mlirAsyncRuntimeAddTokenToGroup"; 53 static constexpr const char *kAwaitTokenAndExecute = 54 "mlirAsyncRuntimeAwaitTokenAndExecute"; 55 static constexpr const char *kAwaitValueAndExecute = 56 "mlirAsyncRuntimeAwaitValueAndExecute"; 57 static constexpr const char *kAwaitAllAndExecute = 58 "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 59 60 namespace { 61 /// Async Runtime API function types. 62 /// 63 /// Because we can't create API function signature for type parametrized 64 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After 65 /// lowering all async data types become opaque pointers at runtime. 66 struct AsyncAPI { 67 // All async types are lowered to opaque i8* LLVM pointers at runtime. 68 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 69 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 70 } 71 72 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 73 auto ref = opaquePointerType(ctx); 74 auto count = IntegerType::get(ctx, 32); 75 return FunctionType::get(ctx, {ref, count}, {}); 76 } 77 78 static FunctionType createTokenFunctionType(MLIRContext *ctx) { 79 return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 80 } 81 82 static FunctionType createValueFunctionType(MLIRContext *ctx) { 83 auto i32 = IntegerType::get(ctx, 32); 84 auto value = opaquePointerType(ctx); 85 return FunctionType::get(ctx, {i32}, {value}); 86 } 87 88 static FunctionType createGroupFunctionType(MLIRContext *ctx) { 89 return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); 90 } 91 92 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 93 auto value = opaquePointerType(ctx); 94 auto storage = opaquePointerType(ctx); 95 return FunctionType::get(ctx, {value}, {storage}); 96 } 97 98 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 99 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 100 } 101 102 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 103 auto value = opaquePointerType(ctx); 104 return FunctionType::get(ctx, {value}, {}); 105 } 106 107 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 108 return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 109 } 110 111 static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 112 auto value = opaquePointerType(ctx); 113 return FunctionType::get(ctx, {value}, {}); 114 } 115 116 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 117 return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 118 } 119 120 static FunctionType executeFunctionType(MLIRContext *ctx) { 121 auto hdl = opaquePointerType(ctx); 122 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 123 return FunctionType::get(ctx, {hdl, resume}, {}); 124 } 125 126 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 127 auto i64 = IntegerType::get(ctx, 64); 128 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 129 {i64}); 130 } 131 132 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 133 auto hdl = opaquePointerType(ctx); 134 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 135 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); 136 } 137 138 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 139 auto value = opaquePointerType(ctx); 140 auto hdl = opaquePointerType(ctx); 141 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 142 return FunctionType::get(ctx, {value, hdl, resume}, {}); 143 } 144 145 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 146 auto hdl = opaquePointerType(ctx); 147 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); 148 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); 149 } 150 151 // Auxiliary coroutine resume intrinsic wrapper. 152 static Type resumeFunctionType(MLIRContext *ctx) { 153 auto voidTy = LLVM::LLVMVoidType::get(ctx); 154 auto i8Ptr = opaquePointerType(ctx); 155 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); 156 } 157 }; 158 } // namespace 159 160 /// Adds Async Runtime C API declarations to the module. 161 static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 162 auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), 163 module.getBody()); 164 165 auto addFuncDecl = [&](StringRef name, FunctionType type) { 166 if (module.lookupSymbol(name)) 167 return; 168 builder.create<FuncOp>(name, type).setPrivate(); 169 }; 170 171 MLIRContext *ctx = module.getContext(); 172 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 173 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 174 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 175 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 176 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 177 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 178 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 179 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 180 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 181 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 182 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 183 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 184 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 185 addFuncDecl(kAwaitTokenAndExecute, 186 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 187 addFuncDecl(kAwaitValueAndExecute, 188 AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 189 addFuncDecl(kAwaitAllAndExecute, 190 AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 191 } 192 193 //===----------------------------------------------------------------------===// 194 // LLVM coroutines intrinsics declarations. 195 //===----------------------------------------------------------------------===// 196 197 static constexpr const char *kCoroId = "llvm.coro.id"; 198 static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; 199 static constexpr const char *kCoroBegin = "llvm.coro.begin"; 200 static constexpr const char *kCoroSave = "llvm.coro.save"; 201 static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; 202 static constexpr const char *kCoroEnd = "llvm.coro.end"; 203 static constexpr const char *kCoroFree = "llvm.coro.free"; 204 static constexpr const char *kCoroResume = "llvm.coro.resume"; 205 206 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, 207 StringRef name, Type ret, ArrayRef<Type> params) { 208 if (module.lookupSymbol(name)) 209 return; 210 Type type = LLVM::LLVMFunctionType::get(ret, params); 211 builder.create<LLVM::LLVMFuncOp>(name, type); 212 } 213 214 /// Adds coroutine intrinsics declarations to the module. 215 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { 216 using namespace mlir::LLVM; 217 218 MLIRContext *ctx = module.getContext(); 219 ImplicitLocOpBuilder builder(module.getLoc(), 220 module.getBody()->getTerminator()); 221 222 auto token = LLVMTokenType::get(ctx); 223 auto voidTy = LLVMVoidType::get(ctx); 224 225 auto i8 = IntegerType::get(ctx, 8); 226 auto i1 = IntegerType::get(ctx, 1); 227 auto i32 = IntegerType::get(ctx, 32); 228 auto i64 = IntegerType::get(ctx, 64); 229 auto i8Ptr = LLVMPointerType::get(i8); 230 231 addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); 232 addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); 233 addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); 234 addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); 235 addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); 236 addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); 237 addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); 238 addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // Add malloc/free declarations to the module. 243 //===----------------------------------------------------------------------===// 244 245 static constexpr const char *kMalloc = "malloc"; 246 static constexpr const char *kFree = "free"; 247 248 /// Adds malloc/free declarations to the module. 249 static void addCRuntimeDeclarations(ModuleOp module) { 250 using namespace mlir::LLVM; 251 252 MLIRContext *ctx = module.getContext(); 253 ImplicitLocOpBuilder builder(module.getLoc(), 254 module.getBody()->getTerminator()); 255 256 auto voidTy = LLVMVoidType::get(ctx); 257 auto i64 = IntegerType::get(ctx, 64); 258 auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); 259 260 addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); 261 addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); 262 } 263 264 //===----------------------------------------------------------------------===// 265 // Coroutine resume function wrapper. 266 //===----------------------------------------------------------------------===// 267 268 static constexpr const char *kResume = "__resume"; 269 270 /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 271 /// intrinsics. We need this function to be able to pass it to the async 272 /// runtime execute API. 273 static void addResumeFunction(ModuleOp module) { 274 MLIRContext *ctx = module.getContext(); 275 276 OpBuilder moduleBuilder(module.getBody()->getTerminator()); 277 Location loc = module.getLoc(); 278 279 if (module.lookupSymbol(kResume)) 280 return; 281 282 auto voidTy = LLVM::LLVMVoidType::get(ctx); 283 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 284 285 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 286 loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); 287 resumeOp.setPrivate(); 288 289 auto *block = resumeOp.addEntryBlock(); 290 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 291 292 blockBuilder.create<LLVM::CallOp>(TypeRange(), 293 blockBuilder.getSymbolRefAttr(kCoroResume), 294 resumeOp.getArgument(0)); 295 296 blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // async.execute op outlining to the coroutine functions. 301 //===----------------------------------------------------------------------===// 302 303 /// Function targeted for coroutine transformation has two additional blocks at 304 /// the end: coroutine cleanup and coroutine suspension. 305 /// 306 /// async.await op lowering additionaly creates a resume block for each 307 /// operation to enable non-blocking waiting via coroutine suspension. 308 namespace { 309 struct CoroMachinery { 310 // Async execute region returns a completion token, and an async value for 311 // each yielded value. 312 // 313 // %token, %result = async.execute -> !async.value<T> { 314 // %0 = constant ... : T 315 // async.yield %0 : T 316 // } 317 Value asyncToken; // token representing completion of the async region 318 llvm::SmallVector<Value, 4> returnValues; // returned async values 319 320 Value coroHandle; 321 Block *cleanup; 322 Block *suspend; 323 }; 324 } // namespace 325 326 /// Builds an coroutine template compatible with LLVM coroutines lowering. 327 /// 328 /// - `entry` block sets up the coroutine. 329 /// - `cleanup` block cleans up the coroutine state. 330 /// - `suspend block after the @llvm.coro.end() defines what value will be 331 /// returned to the initial caller of a coroutine. Everything before the 332 /// @llvm.coro.end() will be executed at every suspension point. 333 /// 334 /// Coroutine structure (only the important bits): 335 /// 336 /// func @async_execute_fn(<function-arguments>) 337 /// -> (!async.token, !async.value<T>) 338 /// { 339 /// ^entryBlock(<function-arguments>): 340 /// %token = <async token> : !async.token // create async runtime token 341 /// %value = <async value> : !async.value<T> // create async value 342 /// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle 343 /// br ^cleanup 344 /// 345 /// ^cleanup: 346 /// llvm.call @llvm.coro.free(...) // delete coroutine state 347 /// br ^suspend 348 /// 349 /// ^suspend: 350 /// llvm.call @llvm.coro.end(...) // marks the end of a coroutine 351 /// return %token, %value : !async.token, !async.value<T> 352 /// } 353 /// 354 /// The actual code for the async.execute operation body region will be inserted 355 /// before the entry block terminator. 356 /// 357 /// 358 static CoroMachinery setupCoroMachinery(FuncOp func) { 359 assert(func.getBody().empty() && "Function must have empty body"); 360 361 MLIRContext *ctx = func.getContext(); 362 363 auto token = LLVM::LLVMTokenType::get(ctx); 364 auto i1 = IntegerType::get(ctx, 1); 365 auto i32 = IntegerType::get(ctx, 32); 366 auto i64 = IntegerType::get(ctx, 64); 367 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); 368 369 Block *entryBlock = func.addEntryBlock(); 370 Location loc = func.getBody().getLoc(); 371 372 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, entryBlock); 373 374 // ------------------------------------------------------------------------ // 375 // Allocate async tokens/values that we will return from a ramp function. 376 // ------------------------------------------------------------------------ // 377 auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx)); 378 379 // Async value operands and results must be convertible to LLVM types. This is 380 // verified before the function outlining. 381 LLVMTypeConverter converter(ctx); 382 383 // Returns the size requirements for the async value storage. 384 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 385 auto sizeOf = [&](ValueType valueType) -> Value { 386 auto storedType = converter.convertType(valueType.getValueType()); 387 auto storagePtrType = LLVM::LLVMPointerType::get(storedType); 388 389 // %Size = getelementptr %T* null, int 1 390 // %SizeI = ptrtoint %T* %Size to i32 391 auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType); 392 auto one = builder.create<LLVM::ConstantOp>(loc, i32, 393 builder.getI32IntegerAttr(1)); 394 auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr, 395 one.getResult()); 396 return builder.create<LLVM::PtrToIntOp>(loc, i32, gep); 397 }; 398 399 // We use the `async.value` type as a return type although it does not match 400 // the `kCreateValue` function signature, because it will be later lowered to 401 // the runtime type (opaque i8* pointer). 402 llvm::SmallVector<CallOp, 4> createValues; 403 for (auto resultType : func.getCallableResults().drop_front(1)) 404 createValues.emplace_back(builder.create<CallOp>( 405 loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>()))); 406 407 auto createdValues = llvm::map_range( 408 createValues, [](CallOp call) { return call.getResult(0); }); 409 llvm::SmallVector<Value, 4> returnValues(createdValues.begin(), 410 createdValues.end()); 411 412 // ------------------------------------------------------------------------ // 413 // Initialize coroutine: allocate frame, get coroutine handle. 414 // ------------------------------------------------------------------------ // 415 416 // Constants for initializing coroutine frame. 417 auto constZero = 418 builder.create<LLVM::ConstantOp>(i32, builder.getI32IntegerAttr(0)); 419 auto constFalse = 420 builder.create<LLVM::ConstantOp>(i1, builder.getBoolAttr(false)); 421 auto nullPtr = builder.create<LLVM::NullOp>(i8Ptr); 422 423 // Get coroutine id: @llvm.coro.id 424 auto coroId = builder.create<LLVM::CallOp>( 425 token, builder.getSymbolRefAttr(kCoroId), 426 ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 427 428 // Get coroutine frame size: @llvm.coro.size.i64 429 auto coroSize = builder.create<LLVM::CallOp>( 430 i64, builder.getSymbolRefAttr(kCoroSizeI64), ValueRange()); 431 432 // Allocate memory for coroutine frame. 433 auto coroAlloc = 434 builder.create<LLVM::CallOp>(i8Ptr, builder.getSymbolRefAttr(kMalloc), 435 ValueRange(coroSize.getResult(0))); 436 437 // Begin a coroutine: @llvm.coro.begin 438 auto coroHdl = builder.create<LLVM::CallOp>( 439 i8Ptr, builder.getSymbolRefAttr(kCoroBegin), 440 ValueRange({coroId.getResult(0), coroAlloc.getResult(0)})); 441 442 Block *cleanupBlock = func.addBlock(); 443 Block *suspendBlock = func.addBlock(); 444 445 // ------------------------------------------------------------------------ // 446 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 447 // ------------------------------------------------------------------------ // 448 builder.setInsertionPointToStart(cleanupBlock); 449 450 // Get a pointer to the coroutine frame memory: @llvm.coro.free. 451 auto coroMem = builder.create<LLVM::CallOp>( 452 i8Ptr, builder.getSymbolRefAttr(kCoroFree), 453 ValueRange({coroId.getResult(0), coroHdl.getResult(0)})); 454 455 // Free the memory. 456 builder.create<LLVM::CallOp>(TypeRange(), builder.getSymbolRefAttr(kFree), 457 ValueRange(coroMem.getResult(0))); 458 // Branch into the suspend block. 459 builder.create<BranchOp>(suspendBlock); 460 461 // ------------------------------------------------------------------------ // 462 // Coroutine suspend block: mark the end of a coroutine and return allocated 463 // async token. 464 // ------------------------------------------------------------------------ // 465 builder.setInsertionPointToStart(suspendBlock); 466 467 // Mark the end of a coroutine: @llvm.coro.end. 468 builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd), 469 ValueRange({coroHdl.getResult(0), constFalse})); 470 471 // Return created `async.token` and `async.values` from the suspend block. 472 // This will be the return value of a coroutine ramp function. 473 SmallVector<Value, 4> ret{createToken.getResult(0)}; 474 ret.insert(ret.end(), returnValues.begin(), returnValues.end()); 475 builder.create<ReturnOp>(loc, ret); 476 477 // Branch from the entry block to the cleanup block to create a valid CFG. 478 builder.setInsertionPointToEnd(entryBlock); 479 480 builder.create<BranchOp>(cleanupBlock); 481 482 // `async.await` op lowering will create resume blocks for async 483 // continuations, and will conditionally branch to cleanup or suspend blocks. 484 485 CoroMachinery machinery; 486 machinery.asyncToken = createToken.getResult(0); 487 machinery.returnValues = returnValues; 488 machinery.coroHandle = coroHdl.getResult(0); 489 machinery.cleanup = cleanupBlock; 490 machinery.suspend = suspendBlock; 491 return machinery; 492 } 493 494 /// Add a LLVM coroutine suspension point to the end of suspended block, to 495 /// resume execution in resume block. The caller is responsible for creating the 496 /// two suspended/resume blocks with the desired ops contained in each block. 497 /// This function merely provides the required control flow logic. 498 /// 499 /// `coroState` must be a value returned from the call to @llvm.coro.save(...) 500 /// intrinsic (saved coroutine state). 501 /// 502 /// Before: 503 /// 504 /// ^bb0: 505 /// "opBefore"(...) 506 /// "op"(...) 507 /// ^cleanup: ... 508 /// ^suspend: ... 509 /// ^resume: 510 /// "op"(...) 511 /// 512 /// After: 513 /// 514 /// ^bb0: 515 /// "opBefore"(...) 516 /// %suspend = llmv.call @llvm.coro.suspend(...) 517 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 518 /// ^resume: 519 /// "op"(...) 520 /// ^cleanup: ... 521 /// ^suspend: ... 522 /// 523 static void addSuspensionPoint(CoroMachinery coro, Value coroState, 524 Operation *op, Block *suspended, Block *resume, 525 OpBuilder &builder) { 526 Location loc = op->getLoc(); 527 MLIRContext *ctx = op->getContext(); 528 auto i1 = IntegerType::get(ctx, 1); 529 auto i8 = IntegerType::get(ctx, 8); 530 531 // Add a coroutine suspension in place of original `op` in the split block. 532 OpBuilder::InsertionGuard guard(builder); 533 builder.setInsertionPointToEnd(suspended); 534 535 auto constFalse = 536 builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false)); 537 538 // Suspend a coroutine: @llvm.coro.suspend 539 auto coroSuspend = builder.create<LLVM::CallOp>( 540 loc, i8, builder.getSymbolRefAttr(kCoroSuspend), 541 ValueRange({coroState, constFalse})); 542 543 // After a suspension point decide if we should branch into resume, cleanup 544 // or suspend block of the coroutine (see @llvm.coro.suspend return code 545 // documentation). 546 auto constZero = 547 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(0)); 548 auto constNegOne = 549 builder.create<LLVM::ConstantOp>(loc, i8, builder.getI8IntegerAttr(-1)); 550 551 Block *resumeOrCleanup = builder.createBlock(resume); 552 553 // Suspend the coroutine ...? 554 builder.setInsertionPointToEnd(suspended); 555 auto isNegOne = builder.create<LLVM::ICmpOp>( 556 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); 557 builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend, 558 /*falseDest=*/resumeOrCleanup); 559 560 // ... or resume or cleanup the coroutine? 561 builder.setInsertionPointToStart(resumeOrCleanup); 562 auto isZero = builder.create<LLVM::ICmpOp>( 563 loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); 564 builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume, 565 /*falseDest=*/coro.cleanup); 566 } 567 568 /// Outline the body region attached to the `async.execute` op into a standalone 569 /// function. 570 /// 571 /// Note that this is not reversible transformation. 572 static std::pair<FuncOp, CoroMachinery> 573 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 574 ModuleOp module = execute->getParentOfType<ModuleOp>(); 575 576 MLIRContext *ctx = module.getContext(); 577 Location loc = execute.getLoc(); 578 579 // Collect all outlined function inputs. 580 llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 581 execute.dependencies().end()); 582 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 583 getUsedValuesDefinedAbove(execute.body(), functionInputs); 584 585 // Collect types for the outlined function inputs and outputs. 586 auto typesRange = llvm::map_range( 587 functionInputs, [](Value value) { return value.getType(); }); 588 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 589 auto outputTypes = execute.getResultTypes(); 590 591 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 592 auto funcAttrs = ArrayRef<NamedAttribute>(); 593 594 // TODO: Derive outlined function name from the parent FuncOp (support 595 // multiple nested async.execute operations). 596 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 597 symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); 598 599 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 600 601 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 602 // blocks, adding llvm.coro instrinsics and setting up control flow. 603 CoroMachinery coro = setupCoroMachinery(func); 604 605 // Suspend async function at the end of an entry block, and resume it using 606 // Async execute API (execution will be resumed in a thread managed by the 607 // async runtime). 608 Block *entryBlock = &func.getBlocks().front(); 609 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 610 611 // A pointer to coroutine resume intrinsic wrapper. 612 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 613 auto resumePtr = builder.create<LLVM::AddressOfOp>( 614 LLVM::LLVMPointerType::get(resumeFnTy), kResume); 615 616 // Save the coroutine state: @llvm.coro.save 617 auto coroSave = builder.create<LLVM::CallOp>( 618 LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 619 ValueRange({coro.coroHandle})); 620 621 // Call async runtime API to execute a coroutine in the managed thread. 622 SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()}; 623 builder.create<CallOp>(TypeRange(), kExecute, executeArgs); 624 625 // Split the entry block before the terminator. 626 auto *terminatorOp = entryBlock->getTerminator(); 627 Block *suspended = terminatorOp->getBlock(); 628 Block *resume = suspended->splitBlock(terminatorOp); 629 addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, 630 resume, builder); 631 632 size_t numDependencies = execute.dependencies().size(); 633 size_t numOperands = execute.operands().size(); 634 635 // Await on all dependencies before starting to execute the body region. 636 builder.setInsertionPointToStart(resume); 637 for (size_t i = 0; i < numDependencies; ++i) 638 builder.create<AwaitOp>(func.getArgument(i)); 639 640 // Await on all async value operands and unwrap the payload. 641 SmallVector<Value, 4> unwrappedOperands(numOperands); 642 for (size_t i = 0; i < numOperands; ++i) { 643 Value operand = func.getArgument(numDependencies + i); 644 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 645 } 646 647 // Map from function inputs defined above the execute op to the function 648 // arguments. 649 BlockAndValueMapping valueMapping; 650 valueMapping.map(functionInputs, func.getArguments()); 651 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 652 653 // Clone all operations from the execute operation body into the outlined 654 // function body. 655 for (Operation &op : execute.body().getOps()) 656 builder.clone(op, valueMapping); 657 658 // Replace the original `async.execute` with a call to outlined function. 659 ImplicitLocOpBuilder callBuilder(loc, execute); 660 auto callOutlinedFunc = callBuilder.create<CallOp>( 661 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 662 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 663 execute.erase(); 664 665 return {func, coro}; 666 } 667 668 //===----------------------------------------------------------------------===// 669 // Convert Async dialect types to LLVM types. 670 //===----------------------------------------------------------------------===// 671 672 namespace { 673 674 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 675 /// their runtime type (opaque pointers) and does not convert any other types. 676 class AsyncRuntimeTypeConverter : public TypeConverter { 677 public: 678 AsyncRuntimeTypeConverter() { 679 addConversion([](Type type) { return type; }); 680 addConversion(convertAsyncTypes); 681 } 682 683 static Optional<Type> convertAsyncTypes(Type type) { 684 if (type.isa<TokenType, GroupType, ValueType>()) 685 return AsyncAPI::opaquePointerType(type.getContext()); 686 return llvm::None; 687 } 688 }; 689 } // namespace 690 691 //===----------------------------------------------------------------------===// 692 // Convert return operations that return async values from async regions. 693 //===----------------------------------------------------------------------===// 694 695 namespace { 696 class ReturnOpOpConversion : public ConversionPattern { 697 public: 698 explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx) 699 : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {} 700 701 LogicalResult 702 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 703 ConversionPatternRewriter &rewriter) const override { 704 rewriter.replaceOpWithNewOp<ReturnOp>(op, operands); 705 return success(); 706 } 707 }; 708 } // namespace 709 710 //===----------------------------------------------------------------------===// 711 // Async reference counting ops lowering (`async.add_ref` and `async.drop_ref` 712 // to the corresponding API calls). 713 //===----------------------------------------------------------------------===// 714 715 namespace { 716 717 template <typename RefCountingOp> 718 class RefCountingOpLowering : public ConversionPattern { 719 public: 720 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, 721 StringRef apiFunctionName) 722 : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx), 723 apiFunctionName(apiFunctionName) {} 724 725 LogicalResult 726 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 727 ConversionPatternRewriter &rewriter) const override { 728 RefCountingOp refCountingOp = cast<RefCountingOp>(op); 729 730 auto count = rewriter.create<ConstantOp>( 731 op->getLoc(), rewriter.getI32Type(), 732 rewriter.getI32IntegerAttr(refCountingOp.count())); 733 734 rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName, 735 ValueRange({operands[0], count})); 736 737 return success(); 738 } 739 740 private: 741 StringRef apiFunctionName; 742 }; 743 744 /// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call. 745 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> { 746 public: 747 explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 748 : RefCountingOpLowering(converter, ctx, kAddRef) {} 749 }; 750 751 /// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 752 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> { 753 public: 754 explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) 755 : RefCountingOpLowering(converter, ctx, kDropRef) {} 756 }; 757 758 } // namespace 759 760 //===----------------------------------------------------------------------===// 761 // async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call. 762 //===----------------------------------------------------------------------===// 763 764 namespace { 765 class CreateGroupOpLowering : public ConversionPattern { 766 public: 767 explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) 768 : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter, 769 ctx) {} 770 771 LogicalResult 772 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 773 ConversionPatternRewriter &rewriter) const override { 774 auto retTy = GroupType::get(op->getContext()); 775 rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, retTy); 776 return success(); 777 } 778 }; 779 } // namespace 780 781 //===----------------------------------------------------------------------===// 782 // async.add_to_group op lowering to runtime function call. 783 //===----------------------------------------------------------------------===// 784 785 namespace { 786 class AddToGroupOpLowering : public ConversionPattern { 787 public: 788 explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx) 789 : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) { 790 } 791 792 LogicalResult 793 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 794 ConversionPatternRewriter &rewriter) const override { 795 // Currently we can only add tokens to the group. 796 auto addToGroup = cast<AddToGroupOp>(op); 797 if (!addToGroup.operand().getType().isa<TokenType>()) 798 return failure(); 799 800 auto i64 = IntegerType::get(op->getContext(), 64); 801 rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); 802 return success(); 803 } 804 }; 805 } // namespace 806 807 //===----------------------------------------------------------------------===// 808 // async.await and async.await_all op lowerings to the corresponding async 809 // runtime function calls. 810 //===----------------------------------------------------------------------===// 811 812 namespace { 813 814 template <typename AwaitType, typename AwaitableType> 815 class AwaitOpLoweringBase : public ConversionPattern { 816 protected: 817 explicit AwaitOpLoweringBase( 818 TypeConverter &converter, MLIRContext *ctx, 819 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions, 820 StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName) 821 : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx), 822 outlinedFunctions(outlinedFunctions), 823 blockingAwaitFuncName(blockingAwaitFuncName), 824 coroAwaitFuncName(coroAwaitFuncName) {} 825 826 public: 827 LogicalResult 828 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 829 ConversionPatternRewriter &rewriter) const override { 830 // We can only await on one the `AwaitableType` (for `await` it can be 831 // a `token` or a `value`, for `await_all` it must be a `group`). 832 auto await = cast<AwaitType>(op); 833 if (!await.operand().getType().template isa<AwaitableType>()) 834 return failure(); 835 836 // Check if await operation is inside the outlined coroutine function. 837 auto func = await->template getParentOfType<FuncOp>(); 838 auto outlined = outlinedFunctions.find(func); 839 const bool isInCoroutine = outlined != outlinedFunctions.end(); 840 841 Location loc = op->getLoc(); 842 843 // Inside regular function we convert await operation to the blocking 844 // async API await function call. 845 if (!isInCoroutine) 846 rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName, 847 ValueRange(operands[0])); 848 849 // Inside the coroutine we convert await operation into coroutine suspension 850 // point, and resume execution asynchronously. 851 if (isInCoroutine) { 852 const CoroMachinery &coro = outlined->getSecond(); 853 854 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 855 MLIRContext *ctx = op->getContext(); 856 857 // A pointer to coroutine resume intrinsic wrapper. 858 auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); 859 auto resumePtr = builder.create<LLVM::AddressOfOp>( 860 LLVM::LLVMPointerType::get(resumeFnTy), kResume); 861 862 // Save the coroutine state: @llvm.coro.save 863 auto coroSave = builder.create<LLVM::CallOp>( 864 LLVM::LLVMTokenType::get(ctx), builder.getSymbolRefAttr(kCoroSave), 865 ValueRange(coro.coroHandle)); 866 867 // Call async runtime API to resume a coroutine in the managed thread when 868 // the async await argument becomes ready. 869 SmallVector<Value, 3> awaitAndExecuteArgs = {operands[0], coro.coroHandle, 870 resumePtr.res()}; 871 builder.create<CallOp>(TypeRange(), coroAwaitFuncName, 872 awaitAndExecuteArgs); 873 874 Block *suspended = op->getBlock(); 875 876 // Split the entry block before the await operation. 877 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 878 addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume, 879 builder); 880 881 // Make sure that replacement value will be constructed in resume block. 882 rewriter.setInsertionPointToStart(resume); 883 } 884 885 // Replace or erase the await operation with the new value. 886 if (Value replaceWith = getReplacementValue(op, operands[0], rewriter)) 887 rewriter.replaceOp(op, replaceWith); 888 else 889 rewriter.eraseOp(op); 890 891 return success(); 892 } 893 894 virtual Value getReplacementValue(Operation *op, Value operand, 895 ConversionPatternRewriter &rewriter) const { 896 return Value(); 897 } 898 899 private: 900 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 901 StringRef blockingAwaitFuncName; 902 StringRef coroAwaitFuncName; 903 }; 904 905 /// Lowering for `async.await` with a token operand. 906 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 907 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 908 909 public: 910 explicit AwaitTokenOpLowering( 911 TypeConverter &converter, MLIRContext *ctx, 912 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 913 : Base(converter, ctx, outlinedFunctions, kAwaitToken, 914 kAwaitTokenAndExecute) {} 915 }; 916 917 /// Lowering for `async.await` with a value operand. 918 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 919 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 920 921 public: 922 explicit AwaitValueOpLowering( 923 TypeConverter &converter, MLIRContext *ctx, 924 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 925 : Base(converter, ctx, outlinedFunctions, kAwaitValue, 926 kAwaitValueAndExecute) {} 927 928 Value 929 getReplacementValue(Operation *op, Value operand, 930 ConversionPatternRewriter &rewriter) const override { 931 Location loc = op->getLoc(); 932 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 933 934 // Get the underlying value type from the `async.value`. 935 auto await = cast<AwaitOp>(op); 936 auto valueType = await.operand().getType().cast<ValueType>().getValueType(); 937 938 // Get a pointer to an async value storage from the runtime. 939 auto storage = rewriter.create<CallOp>(loc, kGetValueStorage, 940 TypeRange(i8Ptr), operand); 941 942 // Cast from i8* to the pointer pointer to LLVM type. 943 auto llvmValueType = getTypeConverter()->convertType(valueType); 944 auto castedStorage = rewriter.create<LLVM::BitcastOp>( 945 loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0)); 946 947 // Load from the async value storage. 948 return rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult()); 949 } 950 }; 951 952 /// Lowering for `async.await_all` operation. 953 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 954 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 955 956 public: 957 explicit AwaitAllOpLowering( 958 TypeConverter &converter, MLIRContext *ctx, 959 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 960 : Base(converter, ctx, outlinedFunctions, kAwaitGroup, 961 kAwaitAllAndExecute) {} 962 }; 963 964 } // namespace 965 966 //===----------------------------------------------------------------------===// 967 // async.yield op lowerings to the corresponding async runtime function calls. 968 //===----------------------------------------------------------------------===// 969 970 class YieldOpLowering : public ConversionPattern { 971 public: 972 explicit YieldOpLowering( 973 TypeConverter &converter, MLIRContext *ctx, 974 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 975 : ConversionPattern(async::YieldOp::getOperationName(), 1, converter, 976 ctx), 977 outlinedFunctions(outlinedFunctions) {} 978 979 LogicalResult 980 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 981 ConversionPatternRewriter &rewriter) const override { 982 // Check if yield operation is inside the outlined coroutine function. 983 auto func = op->template getParentOfType<FuncOp>(); 984 auto outlined = outlinedFunctions.find(func); 985 if (outlined == outlinedFunctions.end()) 986 return op->emitOpError( 987 "async.yield is not inside the outlined coroutine function"); 988 989 Location loc = op->getLoc(); 990 const CoroMachinery &coro = outlined->getSecond(); 991 992 // Store yielded values into the async values storage and emplace them. 993 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); 994 995 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 996 // Store `yieldValue` into the `asyncValue` storage. 997 Value yieldValue = std::get<0>(tuple); 998 Value asyncValue = std::get<1>(tuple); 999 1000 // Get an opaque i8* pointer to an async value storage from the runtime. 1001 auto storage = rewriter.create<CallOp>(loc, kGetValueStorage, 1002 TypeRange(i8Ptr), asyncValue); 1003 1004 // Cast storage pointer to the yielded value type. 1005 auto castedStorage = rewriter.create<LLVM::BitcastOp>( 1006 loc, LLVM::LLVMPointerType::get(yieldValue.getType()), 1007 storage.getResult(0)); 1008 1009 // Store the yielded value into the async value storage. 1010 rewriter.create<LLVM::StoreOp>(loc, yieldValue, 1011 castedStorage.getResult()); 1012 1013 // Emplace the `async.value` to mark it ready. 1014 rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue); 1015 } 1016 1017 // Emplace the completion token to mark it ready. 1018 rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken); 1019 1020 // Original operation was replaced by the function call(s). 1021 rewriter.eraseOp(op); 1022 1023 return success(); 1024 } 1025 1026 private: 1027 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 1028 }; 1029 1030 //===----------------------------------------------------------------------===// 1031 1032 namespace { 1033 struct ConvertAsyncToLLVMPass 1034 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> { 1035 void runOnOperation() override; 1036 }; 1037 1038 void ConvertAsyncToLLVMPass::runOnOperation() { 1039 ModuleOp module = getOperation(); 1040 SymbolTable symbolTable(module); 1041 1042 MLIRContext *ctx = &getContext(); 1043 1044 // Outline all `async.execute` body regions into async functions (coroutines). 1045 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 1046 1047 // We use conversion to LLVM type to ensure that all `async.value` operands 1048 // and results can be lowered to LLVM load and store operations. 1049 LLVMTypeConverter llvmConverter(ctx); 1050 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); 1051 1052 // Returns true if the `async.value` payload is convertible to LLVM. 1053 auto isConvertibleToLlvm = [&](Type type) -> bool { 1054 auto valueType = type.cast<ValueType>().getValueType(); 1055 return static_cast<bool>(llvmConverter.convertType(valueType)); 1056 }; 1057 1058 WalkResult outlineResult = module.walk([&](ExecuteOp execute) { 1059 // All operands and results must be convertible to LLVM. 1060 if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) { 1061 execute.emitOpError("operands payload must be convertible to LLVM type"); 1062 return WalkResult::interrupt(); 1063 } 1064 if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) { 1065 execute.emitOpError("results payload must be convertible to LLVM type"); 1066 return WalkResult::interrupt(); 1067 } 1068 1069 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 1070 1071 return WalkResult::advance(); 1072 }); 1073 1074 // Failed to outline all async execute operations. 1075 if (outlineResult.wasInterrupted()) { 1076 signalPassFailure(); 1077 return; 1078 } 1079 1080 LLVM_DEBUG({ 1081 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 1082 << " async functions\n"; 1083 }); 1084 1085 // Add declarations for all functions required by the coroutines lowering. 1086 addResumeFunction(module); 1087 addAsyncRuntimeApiDeclarations(module); 1088 addCoroutineIntrinsicsDeclarations(module); 1089 addCRuntimeDeclarations(module); 1090 1091 // Convert async dialect types and operations to LLVM dialect. 1092 AsyncRuntimeTypeConverter converter; 1093 OwningRewritePatternList patterns; 1094 1095 // Convert async types in function signatures and function calls. 1096 populateFuncOpTypeConversionPattern(patterns, ctx, converter); 1097 populateCallOpTypeConversionPattern(patterns, ctx, converter); 1098 1099 // Convert return operations inside async.execute regions. 1100 patterns.insert<ReturnOpOpConversion>(converter, ctx); 1101 1102 // Lower async operations to async runtime API calls. 1103 patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx); 1104 patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx); 1105 1106 // Use LLVM type converter to automatically convert between the async value 1107 // payload type and LLVM type when loading/storing from/to the async 1108 // value storage which is an opaque i8* pointer using LLVM load/store ops. 1109 patterns 1110 .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>( 1111 llvmConverter, ctx, outlinedFunctions); 1112 patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions); 1113 1114 ConversionTarget target(*ctx); 1115 target.addLegalOp<ConstantOp>(); 1116 target.addLegalDialect<LLVM::LLVMDialect>(); 1117 1118 // All operations from Async dialect must be lowered to the runtime API calls. 1119 target.addIllegalDialect<AsyncDialect>(); 1120 1121 // Add dynamic legality constraints to apply conversions defined above. 1122 target.addDynamicallyLegalOp<FuncOp>( 1123 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 1124 target.addDynamicallyLegalOp<ReturnOp>( 1125 [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); 1126 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) { 1127 return converter.isSignatureLegal(op.getCalleeType()); 1128 }); 1129 1130 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 1131 signalPassFailure(); 1132 } 1133 } // namespace 1134 1135 namespace { 1136 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1137 public: 1138 using OpConversionPattern::OpConversionPattern; 1139 LogicalResult 1140 matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands, 1141 ConversionPatternRewriter &rewriter) const override { 1142 ExecuteOp newOp = 1143 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1144 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1145 newOp.getRegion().end()); 1146 1147 // Set operands and update block argument and result types. 1148 newOp->setOperands(operands); 1149 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1150 return failure(); 1151 for (auto result : newOp.getResults()) 1152 result.setType(typeConverter->convertType(result.getType())); 1153 1154 rewriter.replaceOp(op, newOp.getResults()); 1155 return success(); 1156 } 1157 }; 1158 1159 // Dummy pattern to trigger the appropriate type conversion / materialization. 1160 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1161 public: 1162 using OpConversionPattern::OpConversionPattern; 1163 LogicalResult 1164 matchAndRewrite(AwaitOp op, ArrayRef<Value> operands, 1165 ConversionPatternRewriter &rewriter) const override { 1166 rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front()); 1167 return success(); 1168 } 1169 }; 1170 1171 // Dummy pattern to trigger the appropriate type conversion / materialization. 1172 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1173 public: 1174 using OpConversionPattern::OpConversionPattern; 1175 LogicalResult 1176 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 1177 ConversionPatternRewriter &rewriter) const override { 1178 rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands); 1179 return success(); 1180 } 1181 }; 1182 } // namespace 1183 1184 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { 1185 return std::make_unique<ConvertAsyncToLLVMPass>(); 1186 } 1187 1188 void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1189 MLIRContext *context, TypeConverter &typeConverter, 1190 OwningRewritePatternList &patterns, ConversionTarget &target) { 1191 typeConverter.addConversion([&](TokenType type) { return type; }); 1192 typeConverter.addConversion([&](ValueType type) { 1193 return ValueType::get(typeConverter.convertType(type.getValueType())); 1194 }); 1195 1196 patterns 1197 .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 1198 typeConverter, context); 1199 1200 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1201 [&](Operation *op) { return typeConverter.isLegal(op); }); 1202 } 1203