1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Async/IR/Async.h" 18 #include "mlir/Dialect/Async/Passes.h" 19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20 #include "mlir/Dialect/Func/IR/FuncOps.h" 21 #include "mlir/Dialect/SCF/SCF.h" 22 #include "mlir/IR/BlockAndValueMapping.h" 23 #include "mlir/IR/ImplicitLocOpBuilder.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "mlir/Transforms/RegionUtils.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/Support/Debug.h" 29 30 using namespace mlir; 31 using namespace mlir::async; 32 33 #define DEBUG_TYPE "async-to-async-runtime" 34 // Prefix for functions outlined from `async.execute` op regions. 35 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 36 37 namespace { 38 39 class AsyncToAsyncRuntimePass 40 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 41 public: 42 AsyncToAsyncRuntimePass() = default; 43 void runOnOperation() override; 44 }; 45 46 } // namespace 47 48 //===----------------------------------------------------------------------===// 49 // async.execute op outlining to the coroutine functions. 50 //===----------------------------------------------------------------------===// 51 52 /// Function targeted for coroutine transformation has two additional blocks at 53 /// the end: coroutine cleanup and coroutine suspension. 54 /// 55 /// async.await op lowering additionaly creates a resume block for each 56 /// operation to enable non-blocking waiting via coroutine suspension. 57 namespace { 58 struct CoroMachinery { 59 func::FuncOp func; 60 61 // Async execute region returns a completion token, and an async value for 62 // each yielded value. 63 // 64 // %token, %result = async.execute -> !async.value<T> { 65 // %0 = arith.constant ... : T 66 // async.yield %0 : T 67 // } 68 Value asyncToken; // token representing completion of the async region 69 llvm::SmallVector<Value, 4> returnValues; // returned async values 70 71 Value coroHandle; // coroutine handle (!async.coro.handle value) 72 Block *entry; // coroutine entry block 73 Block *setError; // switch completion token and all values to error state 74 Block *cleanup; // coroutine cleanup block 75 Block *suspend; // coroutine suspension block 76 }; 77 } // namespace 78 79 /// Utility to partially update the regular function CFG to the coroutine CFG 80 /// compatible with LLVM coroutines switched-resume lowering using 81 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 82 /// that branches into preexisting entry block. Also inserts trailing blocks. 83 /// 84 /// The result types of the passed `func` must start with an `async.token` 85 /// and be continued with some number of `async.value`s. 86 /// 87 /// The func given to this function needs to have been preprocessed to have 88 /// either branch or yield ops as terminators. Branches to the cleanup block are 89 /// inserted after each yield. 90 /// 91 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 92 /// 93 /// - `entry` block sets up the coroutine. 94 /// - `set_error` block sets completion token and async values state to error. 95 /// - `cleanup` block cleans up the coroutine state. 96 /// - `suspend block after the @llvm.coro.end() defines what value will be 97 /// returned to the initial caller of a coroutine. Everything before the 98 /// @llvm.coro.end() will be executed at every suspension point. 99 /// 100 /// Coroutine structure (only the important bits): 101 /// 102 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 103 /// { 104 /// ^entry(<function-arguments>): 105 /// %token = <async token> : !async.token // create async runtime token 106 /// %value = <async value> : !async.value<T> // create async value 107 /// %id = async.coro.id // create a coroutine id 108 /// %hdl = async.coro.begin %id // create a coroutine handle 109 /// cf.br ^preexisting_entry_block 110 /// 111 /// /* preexisting blocks modified to branch to the cleanup block */ 112 /// 113 /// ^set_error: // this block created lazily only if needed (see code below) 114 /// async.runtime.set_error %token : !async.token 115 /// async.runtime.set_error %value : !async.value<T> 116 /// cf.br ^cleanup 117 /// 118 /// ^cleanup: 119 /// async.coro.free %hdl // delete the coroutine state 120 /// cf.br ^suspend 121 /// 122 /// ^suspend: 123 /// async.coro.end %hdl // marks the end of a coroutine 124 /// return %token, %value : !async.token, !async.value<T> 125 /// } 126 /// 127 static CoroMachinery setupCoroMachinery(func::FuncOp func) { 128 assert(!func.getBlocks().empty() && "Function must have an entry block"); 129 130 MLIRContext *ctx = func.getContext(); 131 Block *entryBlock = &func.getBlocks().front(); 132 Block *originalEntryBlock = 133 entryBlock->splitBlock(entryBlock->getOperations().begin()); 134 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 135 136 // ------------------------------------------------------------------------ // 137 // Allocate async token/values that we will return from a ramp function. 138 // ------------------------------------------------------------------------ // 139 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 140 141 llvm::SmallVector<Value, 4> retValues; 142 for (auto resType : func.getCallableResults().drop_front()) 143 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 144 145 // ------------------------------------------------------------------------ // 146 // Initialize coroutine: get coroutine id and coroutine handle. 147 // ------------------------------------------------------------------------ // 148 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 149 auto coroHdlOp = 150 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 151 builder.create<cf::BranchOp>(originalEntryBlock); 152 153 Block *cleanupBlock = func.addBlock(); 154 Block *suspendBlock = func.addBlock(); 155 156 // ------------------------------------------------------------------------ // 157 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 158 // ------------------------------------------------------------------------ // 159 builder.setInsertionPointToStart(cleanupBlock); 160 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 161 162 // Branch into the suspend block. 163 builder.create<cf::BranchOp>(suspendBlock); 164 165 // ------------------------------------------------------------------------ // 166 // Coroutine suspend block: mark the end of a coroutine and return allocated 167 // async token. 168 // ------------------------------------------------------------------------ // 169 builder.setInsertionPointToStart(suspendBlock); 170 171 // Mark the end of a coroutine: async.coro.end 172 builder.create<CoroEndOp>(coroHdlOp.handle()); 173 174 // Return created `async.token` and `async.values` from the suspend block. 175 // This will be the return value of a coroutine ramp function. 176 SmallVector<Value, 4> ret{retToken}; 177 ret.insert(ret.end(), retValues.begin(), retValues.end()); 178 builder.create<func::ReturnOp>(ret); 179 180 // `async.await` op lowering will create resume blocks for async 181 // continuations, and will conditionally branch to cleanup or suspend blocks. 182 183 for (Block &block : func.getBody().getBlocks()) { 184 if (&block == entryBlock || &block == cleanupBlock || 185 &block == suspendBlock) 186 continue; 187 Operation *terminator = block.getTerminator(); 188 if (auto yield = dyn_cast<YieldOp>(terminator)) { 189 builder.setInsertionPointToEnd(&block); 190 builder.create<cf::BranchOp>(cleanupBlock); 191 } 192 } 193 194 // The switch-resumed API based coroutine should be marked with 195 // "coroutine.presplit" attribute with value "0" to mark the function as a 196 // coroutine. 197 func->setAttr("passthrough", builder.getArrayAttr(builder.getArrayAttr( 198 {builder.getStringAttr("coroutine.presplit"), 199 builder.getStringAttr("0")}))); 200 201 CoroMachinery machinery; 202 machinery.func = func; 203 machinery.asyncToken = retToken; 204 machinery.returnValues = retValues; 205 machinery.coroHandle = coroHdlOp.handle(); 206 machinery.entry = entryBlock; 207 machinery.setError = nullptr; // created lazily only if needed 208 machinery.cleanup = cleanupBlock; 209 machinery.suspend = suspendBlock; 210 return machinery; 211 } 212 213 // Lazily creates `set_error` block only if it is required for lowering to the 214 // runtime operations (see for example lowering of assert operation). 215 static Block *setupSetErrorBlock(CoroMachinery &coro) { 216 if (coro.setError) 217 return coro.setError; 218 219 coro.setError = coro.func.addBlock(); 220 coro.setError->moveBefore(coro.cleanup); 221 222 auto builder = 223 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 224 225 // Coroutine set_error block: set error on token and all returned values. 226 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 227 for (Value retValue : coro.returnValues) 228 builder.create<RuntimeSetErrorOp>(retValue); 229 230 // Branch into the cleanup block. 231 builder.create<cf::BranchOp>(coro.cleanup); 232 233 return coro.setError; 234 } 235 236 /// Outline the body region attached to the `async.execute` op into a standalone 237 /// function. 238 /// 239 /// Note that this is not reversible transformation. 240 static std::pair<func::FuncOp, CoroMachinery> 241 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 242 ModuleOp module = execute->getParentOfType<ModuleOp>(); 243 244 MLIRContext *ctx = module.getContext(); 245 Location loc = execute.getLoc(); 246 247 // Make sure that all constants will be inside the outlined async function to 248 // reduce the number of function arguments. 249 cloneConstantsIntoTheRegion(execute.body()); 250 251 // Collect all outlined function inputs. 252 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 253 execute.dependencies().end()); 254 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 255 getUsedValuesDefinedAbove(execute.body(), functionInputs); 256 257 // Collect types for the outlined function inputs and outputs. 258 auto typesRange = llvm::map_range( 259 functionInputs, [](Value value) { return value.getType(); }); 260 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 261 auto outputTypes = execute.getResultTypes(); 262 263 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 264 auto funcAttrs = ArrayRef<NamedAttribute>(); 265 266 // TODO: Derive outlined function name from the parent FuncOp (support 267 // multiple nested async.execute operations). 268 func::FuncOp func = 269 func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 270 symbolTable.insert(func); 271 272 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 273 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 274 275 // Prepare for coroutine conversion by creating the body of the function. 276 { 277 size_t numDependencies = execute.dependencies().size(); 278 size_t numOperands = execute.operands().size(); 279 280 // Await on all dependencies before starting to execute the body region. 281 for (size_t i = 0; i < numDependencies; ++i) 282 builder.create<AwaitOp>(func.getArgument(i)); 283 284 // Await on all async value operands and unwrap the payload. 285 SmallVector<Value, 4> unwrappedOperands(numOperands); 286 for (size_t i = 0; i < numOperands; ++i) { 287 Value operand = func.getArgument(numDependencies + i); 288 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 289 } 290 291 // Map from function inputs defined above the execute op to the function 292 // arguments. 293 BlockAndValueMapping valueMapping; 294 valueMapping.map(functionInputs, func.getArguments()); 295 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 296 297 // Clone all operations from the execute operation body into the outlined 298 // function body. 299 for (Operation &op : execute.body().getOps()) 300 builder.clone(op, valueMapping); 301 } 302 303 // Adding entry/cleanup/suspend blocks. 304 CoroMachinery coro = setupCoroMachinery(func); 305 306 // Suspend async function at the end of an entry block, and resume it using 307 // Async resume operation (execution will be resumed in a thread managed by 308 // the async runtime). 309 { 310 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator()); 311 builder.setInsertionPointToEnd(coro.entry); 312 313 // Save the coroutine state: async.coro.save 314 auto coroSaveOp = 315 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 316 317 // Pass coroutine to the runtime to be resumed on a runtime managed 318 // thread. 319 builder.create<RuntimeResumeOp>(coro.coroHandle); 320 321 // Add async.coro.suspend as a suspended block terminator. 322 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 323 branch.getDest(), coro.cleanup); 324 325 branch.erase(); 326 } 327 328 // Replace the original `async.execute` with a call to outlined function. 329 { 330 ImplicitLocOpBuilder callBuilder(loc, execute); 331 auto callOutlinedFunc = callBuilder.create<func::CallOp>( 332 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 333 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 334 execute.erase(); 335 } 336 337 return {func, coro}; 338 } 339 340 //===----------------------------------------------------------------------===// 341 // Convert async.create_group operation to async.runtime.create_group 342 //===----------------------------------------------------------------------===// 343 344 namespace { 345 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 346 public: 347 using OpConversionPattern::OpConversionPattern; 348 349 LogicalResult 350 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, 351 ConversionPatternRewriter &rewriter) const override { 352 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 353 op, GroupType::get(op->getContext()), adaptor.getOperands()); 354 return success(); 355 } 356 }; 357 } // namespace 358 359 //===----------------------------------------------------------------------===// 360 // Convert async.add_to_group operation to async.runtime.add_to_group. 361 //===----------------------------------------------------------------------===// 362 363 namespace { 364 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 365 public: 366 using OpConversionPattern::OpConversionPattern; 367 368 LogicalResult 369 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, 370 ConversionPatternRewriter &rewriter) const override { 371 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 372 op, rewriter.getIndexType(), adaptor.getOperands()); 373 return success(); 374 } 375 }; 376 } // namespace 377 378 //===----------------------------------------------------------------------===// 379 // Convert async.await and async.await_all operations to the async.runtime.await 380 // or async.runtime.await_and_resume operations. 381 //===----------------------------------------------------------------------===// 382 383 namespace { 384 template <typename AwaitType, typename AwaitableType> 385 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 386 using AwaitAdaptor = typename AwaitType::Adaptor; 387 388 public: 389 AwaitOpLoweringBase( 390 MLIRContext *ctx, 391 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) 392 : OpConversionPattern<AwaitType>(ctx), 393 outlinedFunctions(outlinedFunctions) {} 394 395 LogicalResult 396 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, 397 ConversionPatternRewriter &rewriter) const override { 398 // We can only await on one the `AwaitableType` (for `await` it can be 399 // a `token` or a `value`, for `await_all` it must be a `group`). 400 if (!op.operand().getType().template isa<AwaitableType>()) 401 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 402 403 // Check if await operation is inside the outlined coroutine function. 404 auto func = op->template getParentOfType<func::FuncOp>(); 405 auto outlined = outlinedFunctions.find(func); 406 const bool isInCoroutine = outlined != outlinedFunctions.end(); 407 408 Location loc = op->getLoc(); 409 Value operand = adaptor.operand(); 410 411 Type i1 = rewriter.getI1Type(); 412 413 // Inside regular functions we use the blocking wait operation to wait for 414 // the async object (token, value or group) to become available. 415 if (!isInCoroutine) { 416 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 417 builder.create<RuntimeAwaitOp>(loc, operand); 418 419 // Assert that the awaited operands is not in the error state. 420 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand); 421 Value notError = builder.create<arith::XOrIOp>( 422 isError, builder.create<arith::ConstantOp>( 423 loc, i1, builder.getIntegerAttr(i1, 1))); 424 425 builder.create<cf::AssertOp>(notError, 426 "Awaited async operand is in error state"); 427 } 428 429 // Inside the coroutine we convert await operation into coroutine suspension 430 // point, and resume execution asynchronously. 431 if (isInCoroutine) { 432 CoroMachinery &coro = outlined->getSecond(); 433 Block *suspended = op->getBlock(); 434 435 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 436 MLIRContext *ctx = op->getContext(); 437 438 // Save the coroutine state and resume on a runtime managed thread when 439 // the operand becomes available. 440 auto coroSaveOp = 441 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 442 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 443 444 // Split the entry block before the await operation. 445 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 446 447 // Add async.coro.suspend as a suspended block terminator. 448 builder.setInsertionPointToEnd(suspended); 449 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 450 coro.cleanup); 451 452 // Split the resume block into error checking and continuation. 453 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 454 455 // Check if the awaited value is in the error state. 456 builder.setInsertionPointToStart(resume); 457 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); 458 builder.create<cf::CondBranchOp>(isError, 459 /*trueDest=*/setupSetErrorBlock(coro), 460 /*trueArgs=*/ArrayRef<Value>(), 461 /*falseDest=*/continuation, 462 /*falseArgs=*/ArrayRef<Value>()); 463 464 // Make sure that replacement value will be constructed in the 465 // continuation block. 466 rewriter.setInsertionPointToStart(continuation); 467 } 468 469 // Erase or replace the await operation with the new value. 470 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 471 rewriter.replaceOp(op, replaceWith); 472 else 473 rewriter.eraseOp(op); 474 475 return success(); 476 } 477 478 virtual Value getReplacementValue(AwaitType op, Value operand, 479 ConversionPatternRewriter &rewriter) const { 480 return Value(); 481 } 482 483 private: 484 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions; 485 }; 486 487 /// Lowering for `async.await` with a token operand. 488 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 489 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 490 491 public: 492 using Base::Base; 493 }; 494 495 /// Lowering for `async.await` with a value operand. 496 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 497 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 498 499 public: 500 using Base::Base; 501 502 Value 503 getReplacementValue(AwaitOp op, Value operand, 504 ConversionPatternRewriter &rewriter) const override { 505 // Load from the async value storage. 506 auto valueType = operand.getType().cast<ValueType>().getValueType(); 507 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 508 } 509 }; 510 511 /// Lowering for `async.await_all` operation. 512 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 513 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 514 515 public: 516 using Base::Base; 517 }; 518 519 } // namespace 520 521 //===----------------------------------------------------------------------===// 522 // Convert async.yield operation to async.runtime operations. 523 //===----------------------------------------------------------------------===// 524 525 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 526 public: 527 YieldOpLowering( 528 MLIRContext *ctx, 529 const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) 530 : OpConversionPattern<async::YieldOp>(ctx), 531 outlinedFunctions(outlinedFunctions) {} 532 533 LogicalResult 534 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 535 ConversionPatternRewriter &rewriter) const override { 536 // Check if yield operation is inside the async coroutine function. 537 auto func = op->template getParentOfType<func::FuncOp>(); 538 auto outlined = outlinedFunctions.find(func); 539 if (outlined == outlinedFunctions.end()) 540 return rewriter.notifyMatchFailure( 541 op, "operation is not inside the async coroutine function"); 542 543 Location loc = op->getLoc(); 544 const CoroMachinery &coro = outlined->getSecond(); 545 546 // Store yielded values into the async values storage and switch async 547 // values state to available. 548 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { 549 Value yieldValue = std::get<0>(tuple); 550 Value asyncValue = std::get<1>(tuple); 551 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 552 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 553 } 554 555 // Switch the coroutine completion token to available state. 556 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 557 558 return success(); 559 } 560 561 private: 562 const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions; 563 }; 564 565 //===----------------------------------------------------------------------===// 566 // Convert cf.assert operation to cf.cond_br into `set_error` block. 567 //===----------------------------------------------------------------------===// 568 569 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> { 570 public: 571 AssertOpLowering( 572 MLIRContext *ctx, 573 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) 574 : OpConversionPattern<cf::AssertOp>(ctx), 575 outlinedFunctions(outlinedFunctions) {} 576 577 LogicalResult 578 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 579 ConversionPatternRewriter &rewriter) const override { 580 // Check if assert operation is inside the async coroutine function. 581 auto func = op->template getParentOfType<func::FuncOp>(); 582 auto outlined = outlinedFunctions.find(func); 583 if (outlined == outlinedFunctions.end()) 584 return rewriter.notifyMatchFailure( 585 op, "operation is not inside the async coroutine function"); 586 587 Location loc = op->getLoc(); 588 CoroMachinery &coro = outlined->getSecond(); 589 590 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 591 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 592 rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), 593 /*trueDest=*/cont, 594 /*trueArgs=*/ArrayRef<Value>(), 595 /*falseDest=*/setupSetErrorBlock(coro), 596 /*falseArgs=*/ArrayRef<Value>()); 597 rewriter.eraseOp(op); 598 599 return success(); 600 } 601 602 private: 603 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions; 604 }; 605 606 //===----------------------------------------------------------------------===// 607 608 /// Rewrite a func as a coroutine by: 609 /// 1) Wrapping the results into `async.value`. 610 /// 2) Prepending the results with `async.token`. 611 /// 3) Setting up coroutine blocks. 612 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 613 static CoroMachinery rewriteFuncAsCoroutine(func::FuncOp func) { 614 auto *ctx = func->getContext(); 615 auto loc = func.getLoc(); 616 SmallVector<Type> resultTypes; 617 resultTypes.reserve(func.getCallableResults().size()); 618 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 619 [](Type type) { return ValueType::get(type); }); 620 func.setType( 621 FunctionType::get(ctx, func.getFunctionType().getInputs(), resultTypes)); 622 func.insertResult(0, TokenType::get(ctx), {}); 623 for (Block &block : func.getBlocks()) { 624 Operation *terminator = block.getTerminator(); 625 if (auto returnOp = dyn_cast<func::ReturnOp>(*terminator)) { 626 ImplicitLocOpBuilder builder(loc, returnOp); 627 builder.create<YieldOp>(returnOp.getOperands()); 628 returnOp.erase(); 629 } 630 } 631 return setupCoroMachinery(func); 632 } 633 634 /// Rewrites a call into a function that has been rewritten as a coroutine. 635 /// 636 /// The invocation of this function is safe only when call ops are traversed in 637 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 638 static void rewriteCallsiteForCoroutine(func::CallOp oldCall, 639 func::FuncOp func) { 640 auto loc = func.getLoc(); 641 ImplicitLocOpBuilder callBuilder(loc, oldCall); 642 auto newCall = callBuilder.create<func::CallOp>( 643 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 644 645 // Await on the async token and all the value results and unwrap the latter. 646 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 647 SmallVector<Value> unwrappedResults; 648 unwrappedResults.reserve(newCall->getResults().size() - 1); 649 for (Value result : newCall.getResults().drop_front()) 650 unwrappedResults.push_back( 651 callBuilder.create<AwaitOp>(loc, result).result()); 652 // Careful, when result of a call is piped into another call this could lead 653 // to a dangling pointer. 654 oldCall.replaceAllUsesWith(unwrappedResults); 655 oldCall.erase(); 656 } 657 658 static bool isAllowedToBlock(func::FuncOp func) { 659 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 660 } 661 662 static LogicalResult funcsToCoroutines( 663 ModuleOp module, 664 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) { 665 // The following code supports the general case when 2 functions mutually 666 // recurse into each other. Because of this and that we are relying on 667 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 668 // a FuncOp while inserting an equivalent coroutine, because that could lead 669 // to dangling pointers. 670 671 SmallVector<func::FuncOp> funcWorklist; 672 673 // Careful, it's okay to add a func to the worklist multiple times if and only 674 // if the loop processing the worklist will skip the functions that have 675 // already been converted to coroutines. 676 auto addToWorklist = [&](func::FuncOp func) { 677 if (isAllowedToBlock(func)) 678 return; 679 // N.B. To refactor this code into a separate pass the lookup in 680 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 681 // func and recognizing if it has a coroutine structure is messy. Passing 682 // this dict between the passes is ugly. 683 if (isAllowedToBlock(func) || 684 outlinedFunctions.find(func) == outlinedFunctions.end()) { 685 for (Operation &op : func.getBody().getOps()) { 686 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 687 funcWorklist.push_back(func); 688 break; 689 } 690 } 691 } 692 }; 693 694 // Traverse in post-order collecting for each func op the await ops it has. 695 for (func::FuncOp func : module.getOps<func::FuncOp>()) 696 addToWorklist(func); 697 698 SymbolTableCollection symbolTable; 699 SymbolUserMap symbolUserMap(symbolTable, module); 700 701 // Rewrite funcs, while updating call sites and adding them to the worklist. 702 while (!funcWorklist.empty()) { 703 auto func = funcWorklist.pop_back_val(); 704 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 705 if (!insertion.second) 706 // This function has already been processed because this is either 707 // the corecursive case, or a caller with multiple calls to a newly 708 // created corouting. Either way, skip updating the call sites. 709 continue; 710 insertion.first->second = rewriteFuncAsCoroutine(func); 711 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 712 symbolUserMap.getUsers(func).end()); 713 // If there are multiple calls from the same block they need to be traversed 714 // in reverse order so that symbolUserMap references are not invalidated 715 // when updating the users of the call op which is earlier in the block. 716 llvm::sort(users, [](Operation *a, Operation *b) { 717 Block *blockA = a->getBlock(); 718 Block *blockB = b->getBlock(); 719 // Impose arbitrary order on blocks so that there is a well-defined order. 720 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 721 }); 722 // Rewrite the callsites to await on results of the newly created coroutine. 723 for (Operation *op : users) { 724 if (func::CallOp call = dyn_cast<func::CallOp>(*op)) { 725 func::FuncOp caller = call->getParentOfType<func::FuncOp>(); 726 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 727 addToWorklist(caller); 728 } else { 729 op->emitError("Unexpected reference to func referenced by symbol"); 730 return failure(); 731 } 732 } 733 } 734 return success(); 735 } 736 737 //===----------------------------------------------------------------------===// 738 void AsyncToAsyncRuntimePass::runOnOperation() { 739 ModuleOp module = getOperation(); 740 SymbolTable symbolTable(module); 741 742 // Outline all `async.execute` body regions into async functions (coroutines). 743 llvm::DenseMap<func::FuncOp, CoroMachinery> outlinedFunctions; 744 745 module.walk([&](ExecuteOp execute) { 746 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 747 }); 748 749 LLVM_DEBUG({ 750 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 751 << " functions built from async.execute operations\n"; 752 }); 753 754 // Returns true if operation is inside the coroutine. 755 auto isInCoroutine = [&](Operation *op) -> bool { 756 auto parentFunc = op->getParentOfType<func::FuncOp>(); 757 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 758 }; 759 760 if (eliminateBlockingAwaitOps && 761 failed(funcsToCoroutines(module, outlinedFunctions))) { 762 signalPassFailure(); 763 return; 764 } 765 766 // Lower async operations to async.runtime operations. 767 MLIRContext *ctx = module->getContext(); 768 RewritePatternSet asyncPatterns(ctx); 769 770 // Conversion to async runtime augments original CFG with the coroutine CFG, 771 // and we have to make sure that structured control flow operations with async 772 // operations in nested regions will be converted to branch-based control flow 773 // before we add the coroutine basic blocks. 774 populateSCFToControlFlowConversionPatterns(asyncPatterns); 775 776 // Async lowering does not use type converter because it must preserve all 777 // types for async.runtime operations. 778 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 779 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 780 AwaitAllOpLowering, YieldOpLowering>(ctx, 781 outlinedFunctions); 782 783 // Lower assertions to conditional branches into error blocks. 784 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 785 786 // All high level async operations must be lowered to the runtime operations. 787 ConversionTarget runtimeTarget(*ctx); 788 runtimeTarget.addLegalDialect<AsyncDialect>(); 789 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 790 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 791 792 // Decide if structured control flow has to be lowered to branch-based CFG. 793 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 794 auto walkResult = op->walk([&](Operation *nested) { 795 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 796 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 797 : WalkResult::advance(); 798 }); 799 return !walkResult.wasInterrupted(); 800 }); 801 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp, 802 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>(); 803 804 // Assertions must be converted to runtime errors inside async functions. 805 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>( 806 [&](cf::AssertOp op) -> bool { 807 auto func = op->getParentOfType<func::FuncOp>(); 808 return outlinedFunctions.find(func) == outlinedFunctions.end(); 809 }); 810 811 if (eliminateBlockingAwaitOps) 812 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 813 [&](RuntimeAwaitOp op) -> bool { 814 return isAllowedToBlock(op->getParentOfType<func::FuncOp>()); 815 }); 816 817 if (failed(applyPartialConversion(module, runtimeTarget, 818 std::move(asyncPatterns)))) { 819 signalPassFailure(); 820 return; 821 } 822 } 823 824 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 825 return std::make_unique<AsyncToAsyncRuntimePass>(); 826 } 827