1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Async/IR/Async.h" 18 #include "mlir/Dialect/Async/Passes.h" 19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/IR/BlockAndValueMapping.h" 23 #include "mlir/IR/ImplicitLocOpBuilder.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "mlir/Transforms/RegionUtils.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/Support/Debug.h" 29 30 using namespace mlir; 31 using namespace mlir::async; 32 33 #define DEBUG_TYPE "async-to-async-runtime" 34 // Prefix for functions outlined from `async.execute` op regions. 35 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 36 37 namespace { 38 39 class AsyncToAsyncRuntimePass 40 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 41 public: 42 AsyncToAsyncRuntimePass() = default; 43 void runOnOperation() override; 44 }; 45 46 } // namespace 47 48 //===----------------------------------------------------------------------===// 49 // async.execute op outlining to the coroutine functions. 50 //===----------------------------------------------------------------------===// 51 52 /// Function targeted for coroutine transformation has two additional blocks at 53 /// the end: coroutine cleanup and coroutine suspension. 54 /// 55 /// async.await op lowering additionaly creates a resume block for each 56 /// operation to enable non-blocking waiting via coroutine suspension. 57 namespace { 58 struct CoroMachinery { 59 FuncOp func; 60 61 // Async execute region returns a completion token, and an async value for 62 // each yielded value. 63 // 64 // %token, %result = async.execute -> !async.value<T> { 65 // %0 = arith.constant ... : T 66 // async.yield %0 : T 67 // } 68 Value asyncToken; // token representing completion of the async region 69 llvm::SmallVector<Value, 4> returnValues; // returned async values 70 71 Value coroHandle; // coroutine handle (!async.coro.handle value) 72 Block *entry; // coroutine entry block 73 Block *setError; // switch completion token and all values to error state 74 Block *cleanup; // coroutine cleanup block 75 Block *suspend; // coroutine suspension block 76 }; 77 } // namespace 78 79 /// Utility to partially update the regular function CFG to the coroutine CFG 80 /// compatible with LLVM coroutines switched-resume lowering using 81 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 82 /// that branches into preexisting entry block. Also inserts trailing blocks. 83 /// 84 /// The result types of the passed `func` must start with an `async.token` 85 /// and be continued with some number of `async.value`s. 86 /// 87 /// The func given to this function needs to have been preprocessed to have 88 /// either branch or yield ops as terminators. Branches to the cleanup block are 89 /// inserted after each yield. 90 /// 91 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 92 /// 93 /// - `entry` block sets up the coroutine. 94 /// - `set_error` block sets completion token and async values state to error. 95 /// - `cleanup` block cleans up the coroutine state. 96 /// - `suspend block after the @llvm.coro.end() defines what value will be 97 /// returned to the initial caller of a coroutine. Everything before the 98 /// @llvm.coro.end() will be executed at every suspension point. 99 /// 100 /// Coroutine structure (only the important bits): 101 /// 102 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 103 /// { 104 /// ^entry(<function-arguments>): 105 /// %token = <async token> : !async.token // create async runtime token 106 /// %value = <async value> : !async.value<T> // create async value 107 /// %id = async.coro.id // create a coroutine id 108 /// %hdl = async.coro.begin %id // create a coroutine handle 109 /// cf.br ^preexisting_entry_block 110 /// 111 /// /* preexisting blocks modified to branch to the cleanup block */ 112 /// 113 /// ^set_error: // this block created lazily only if needed (see code below) 114 /// async.runtime.set_error %token : !async.token 115 /// async.runtime.set_error %value : !async.value<T> 116 /// cf.br ^cleanup 117 /// 118 /// ^cleanup: 119 /// async.coro.free %hdl // delete the coroutine state 120 /// cf.br ^suspend 121 /// 122 /// ^suspend: 123 /// async.coro.end %hdl // marks the end of a coroutine 124 /// return %token, %value : !async.token, !async.value<T> 125 /// } 126 /// 127 static CoroMachinery setupCoroMachinery(FuncOp func) { 128 assert(!func.getBlocks().empty() && "Function must have an entry block"); 129 130 MLIRContext *ctx = func.getContext(); 131 Block *entryBlock = &func.getBlocks().front(); 132 Block *originalEntryBlock = 133 entryBlock->splitBlock(entryBlock->getOperations().begin()); 134 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 135 136 // ------------------------------------------------------------------------ // 137 // Allocate async token/values that we will return from a ramp function. 138 // ------------------------------------------------------------------------ // 139 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 140 141 llvm::SmallVector<Value, 4> retValues; 142 for (auto resType : func.getCallableResults().drop_front()) 143 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 144 145 // ------------------------------------------------------------------------ // 146 // Initialize coroutine: get coroutine id and coroutine handle. 147 // ------------------------------------------------------------------------ // 148 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 149 auto coroHdlOp = 150 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 151 builder.create<cf::BranchOp>(originalEntryBlock); 152 153 Block *cleanupBlock = func.addBlock(); 154 Block *suspendBlock = func.addBlock(); 155 156 // ------------------------------------------------------------------------ // 157 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 158 // ------------------------------------------------------------------------ // 159 builder.setInsertionPointToStart(cleanupBlock); 160 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 161 162 // Branch into the suspend block. 163 builder.create<cf::BranchOp>(suspendBlock); 164 165 // ------------------------------------------------------------------------ // 166 // Coroutine suspend block: mark the end of a coroutine and return allocated 167 // async token. 168 // ------------------------------------------------------------------------ // 169 builder.setInsertionPointToStart(suspendBlock); 170 171 // Mark the end of a coroutine: async.coro.end 172 builder.create<CoroEndOp>(coroHdlOp.handle()); 173 174 // Return created `async.token` and `async.values` from the suspend block. 175 // This will be the return value of a coroutine ramp function. 176 SmallVector<Value, 4> ret{retToken}; 177 ret.insert(ret.end(), retValues.begin(), retValues.end()); 178 builder.create<ReturnOp>(ret); 179 180 // `async.await` op lowering will create resume blocks for async 181 // continuations, and will conditionally branch to cleanup or suspend blocks. 182 183 for (Block &block : func.body().getBlocks()) { 184 if (&block == entryBlock || &block == cleanupBlock || 185 &block == suspendBlock) 186 continue; 187 Operation *terminator = block.getTerminator(); 188 if (auto yield = dyn_cast<YieldOp>(terminator)) { 189 builder.setInsertionPointToEnd(&block); 190 builder.create<cf::BranchOp>(cleanupBlock); 191 } 192 } 193 194 // The switch-resumed API based coroutine should be marked with 195 // "coroutine.presplit" attribute with value "0" to mark the function as a 196 // coroutine. 197 func->setAttr("passthrough", builder.getArrayAttr(builder.getArrayAttr( 198 {builder.getStringAttr("coroutine.presplit"), 199 builder.getStringAttr("0")}))); 200 201 CoroMachinery machinery; 202 machinery.func = func; 203 machinery.asyncToken = retToken; 204 machinery.returnValues = retValues; 205 machinery.coroHandle = coroHdlOp.handle(); 206 machinery.entry = entryBlock; 207 machinery.setError = nullptr; // created lazily only if needed 208 machinery.cleanup = cleanupBlock; 209 machinery.suspend = suspendBlock; 210 return machinery; 211 } 212 213 // Lazily creates `set_error` block only if it is required for lowering to the 214 // runtime operations (see for example lowering of assert operation). 215 static Block *setupSetErrorBlock(CoroMachinery &coro) { 216 if (coro.setError) 217 return coro.setError; 218 219 coro.setError = coro.func.addBlock(); 220 coro.setError->moveBefore(coro.cleanup); 221 222 auto builder = 223 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 224 225 // Coroutine set_error block: set error on token and all returned values. 226 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 227 for (Value retValue : coro.returnValues) 228 builder.create<RuntimeSetErrorOp>(retValue); 229 230 // Branch into the cleanup block. 231 builder.create<cf::BranchOp>(coro.cleanup); 232 233 return coro.setError; 234 } 235 236 /// Outline the body region attached to the `async.execute` op into a standalone 237 /// function. 238 /// 239 /// Note that this is not reversible transformation. 240 static std::pair<FuncOp, CoroMachinery> 241 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 242 ModuleOp module = execute->getParentOfType<ModuleOp>(); 243 244 MLIRContext *ctx = module.getContext(); 245 Location loc = execute.getLoc(); 246 247 // Make sure that all constants will be inside the outlined async function to 248 // reduce the number of function arguments. 249 cloneConstantsIntoTheRegion(execute.body()); 250 251 // Collect all outlined function inputs. 252 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 253 execute.dependencies().end()); 254 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 255 getUsedValuesDefinedAbove(execute.body(), functionInputs); 256 257 // Collect types for the outlined function inputs and outputs. 258 auto typesRange = llvm::map_range( 259 functionInputs, [](Value value) { return value.getType(); }); 260 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 261 auto outputTypes = execute.getResultTypes(); 262 263 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 264 auto funcAttrs = ArrayRef<NamedAttribute>(); 265 266 // TODO: Derive outlined function name from the parent FuncOp (support 267 // multiple nested async.execute operations). 268 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 269 symbolTable.insert(func); 270 271 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 272 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 273 274 // Prepare for coroutine conversion by creating the body of the function. 275 { 276 size_t numDependencies = execute.dependencies().size(); 277 size_t numOperands = execute.operands().size(); 278 279 // Await on all dependencies before starting to execute the body region. 280 for (size_t i = 0; i < numDependencies; ++i) 281 builder.create<AwaitOp>(func.getArgument(i)); 282 283 // Await on all async value operands and unwrap the payload. 284 SmallVector<Value, 4> unwrappedOperands(numOperands); 285 for (size_t i = 0; i < numOperands; ++i) { 286 Value operand = func.getArgument(numDependencies + i); 287 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 288 } 289 290 // Map from function inputs defined above the execute op to the function 291 // arguments. 292 BlockAndValueMapping valueMapping; 293 valueMapping.map(functionInputs, func.getArguments()); 294 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 295 296 // Clone all operations from the execute operation body into the outlined 297 // function body. 298 for (Operation &op : execute.body().getOps()) 299 builder.clone(op, valueMapping); 300 } 301 302 // Adding entry/cleanup/suspend blocks. 303 CoroMachinery coro = setupCoroMachinery(func); 304 305 // Suspend async function at the end of an entry block, and resume it using 306 // Async resume operation (execution will be resumed in a thread managed by 307 // the async runtime). 308 { 309 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator()); 310 builder.setInsertionPointToEnd(coro.entry); 311 312 // Save the coroutine state: async.coro.save 313 auto coroSaveOp = 314 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 315 316 // Pass coroutine to the runtime to be resumed on a runtime managed 317 // thread. 318 builder.create<RuntimeResumeOp>(coro.coroHandle); 319 320 // Add async.coro.suspend as a suspended block terminator. 321 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 322 branch.getDest(), coro.cleanup); 323 324 branch.erase(); 325 } 326 327 // Replace the original `async.execute` with a call to outlined function. 328 { 329 ImplicitLocOpBuilder callBuilder(loc, execute); 330 auto callOutlinedFunc = callBuilder.create<CallOp>( 331 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 332 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 333 execute.erase(); 334 } 335 336 return {func, coro}; 337 } 338 339 //===----------------------------------------------------------------------===// 340 // Convert async.create_group operation to async.runtime.create_group 341 //===----------------------------------------------------------------------===// 342 343 namespace { 344 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 345 public: 346 using OpConversionPattern::OpConversionPattern; 347 348 LogicalResult 349 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, 350 ConversionPatternRewriter &rewriter) const override { 351 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 352 op, GroupType::get(op->getContext()), adaptor.getOperands()); 353 return success(); 354 } 355 }; 356 } // namespace 357 358 //===----------------------------------------------------------------------===// 359 // Convert async.add_to_group operation to async.runtime.add_to_group. 360 //===----------------------------------------------------------------------===// 361 362 namespace { 363 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 364 public: 365 using OpConversionPattern::OpConversionPattern; 366 367 LogicalResult 368 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, 369 ConversionPatternRewriter &rewriter) const override { 370 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 371 op, rewriter.getIndexType(), adaptor.getOperands()); 372 return success(); 373 } 374 }; 375 } // namespace 376 377 //===----------------------------------------------------------------------===// 378 // Convert async.await and async.await_all operations to the async.runtime.await 379 // or async.runtime.await_and_resume operations. 380 //===----------------------------------------------------------------------===// 381 382 namespace { 383 template <typename AwaitType, typename AwaitableType> 384 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 385 using AwaitAdaptor = typename AwaitType::Adaptor; 386 387 public: 388 AwaitOpLoweringBase(MLIRContext *ctx, 389 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 390 : OpConversionPattern<AwaitType>(ctx), 391 outlinedFunctions(outlinedFunctions) {} 392 393 LogicalResult 394 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, 395 ConversionPatternRewriter &rewriter) const override { 396 // We can only await on one the `AwaitableType` (for `await` it can be 397 // a `token` or a `value`, for `await_all` it must be a `group`). 398 if (!op.operand().getType().template isa<AwaitableType>()) 399 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 400 401 // Check if await operation is inside the outlined coroutine function. 402 auto func = op->template getParentOfType<FuncOp>(); 403 auto outlined = outlinedFunctions.find(func); 404 const bool isInCoroutine = outlined != outlinedFunctions.end(); 405 406 Location loc = op->getLoc(); 407 Value operand = adaptor.operand(); 408 409 Type i1 = rewriter.getI1Type(); 410 411 // Inside regular functions we use the blocking wait operation to wait for 412 // the async object (token, value or group) to become available. 413 if (!isInCoroutine) { 414 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 415 builder.create<RuntimeAwaitOp>(loc, operand); 416 417 // Assert that the awaited operands is not in the error state. 418 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand); 419 Value notError = builder.create<arith::XOrIOp>( 420 isError, builder.create<arith::ConstantOp>( 421 loc, i1, builder.getIntegerAttr(i1, 1))); 422 423 builder.create<cf::AssertOp>(notError, 424 "Awaited async operand is in error state"); 425 } 426 427 // Inside the coroutine we convert await operation into coroutine suspension 428 // point, and resume execution asynchronously. 429 if (isInCoroutine) { 430 CoroMachinery &coro = outlined->getSecond(); 431 Block *suspended = op->getBlock(); 432 433 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 434 MLIRContext *ctx = op->getContext(); 435 436 // Save the coroutine state and resume on a runtime managed thread when 437 // the operand becomes available. 438 auto coroSaveOp = 439 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 440 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 441 442 // Split the entry block before the await operation. 443 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 444 445 // Add async.coro.suspend as a suspended block terminator. 446 builder.setInsertionPointToEnd(suspended); 447 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 448 coro.cleanup); 449 450 // Split the resume block into error checking and continuation. 451 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 452 453 // Check if the awaited value is in the error state. 454 builder.setInsertionPointToStart(resume); 455 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); 456 builder.create<cf::CondBranchOp>(isError, 457 /*trueDest=*/setupSetErrorBlock(coro), 458 /*trueArgs=*/ArrayRef<Value>(), 459 /*falseDest=*/continuation, 460 /*falseArgs=*/ArrayRef<Value>()); 461 462 // Make sure that replacement value will be constructed in the 463 // continuation block. 464 rewriter.setInsertionPointToStart(continuation); 465 } 466 467 // Erase or replace the await operation with the new value. 468 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 469 rewriter.replaceOp(op, replaceWith); 470 else 471 rewriter.eraseOp(op); 472 473 return success(); 474 } 475 476 virtual Value getReplacementValue(AwaitType op, Value operand, 477 ConversionPatternRewriter &rewriter) const { 478 return Value(); 479 } 480 481 private: 482 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 483 }; 484 485 /// Lowering for `async.await` with a token operand. 486 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 487 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 488 489 public: 490 using Base::Base; 491 }; 492 493 /// Lowering for `async.await` with a value operand. 494 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 495 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 496 497 public: 498 using Base::Base; 499 500 Value 501 getReplacementValue(AwaitOp op, Value operand, 502 ConversionPatternRewriter &rewriter) const override { 503 // Load from the async value storage. 504 auto valueType = operand.getType().cast<ValueType>().getValueType(); 505 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 506 } 507 }; 508 509 /// Lowering for `async.await_all` operation. 510 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 511 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 512 513 public: 514 using Base::Base; 515 }; 516 517 } // namespace 518 519 //===----------------------------------------------------------------------===// 520 // Convert async.yield operation to async.runtime operations. 521 //===----------------------------------------------------------------------===// 522 523 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 524 public: 525 YieldOpLowering( 526 MLIRContext *ctx, 527 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 528 : OpConversionPattern<async::YieldOp>(ctx), 529 outlinedFunctions(outlinedFunctions) {} 530 531 LogicalResult 532 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 533 ConversionPatternRewriter &rewriter) const override { 534 // Check if yield operation is inside the async coroutine function. 535 auto func = op->template getParentOfType<FuncOp>(); 536 auto outlined = outlinedFunctions.find(func); 537 if (outlined == outlinedFunctions.end()) 538 return rewriter.notifyMatchFailure( 539 op, "operation is not inside the async coroutine function"); 540 541 Location loc = op->getLoc(); 542 const CoroMachinery &coro = outlined->getSecond(); 543 544 // Store yielded values into the async values storage and switch async 545 // values state to available. 546 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { 547 Value yieldValue = std::get<0>(tuple); 548 Value asyncValue = std::get<1>(tuple); 549 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 550 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 551 } 552 553 // Switch the coroutine completion token to available state. 554 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 555 556 return success(); 557 } 558 559 private: 560 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 561 }; 562 563 //===----------------------------------------------------------------------===// 564 // Convert std.assert operation to cf.cond_br into `set_error` block. 565 //===----------------------------------------------------------------------===// 566 567 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> { 568 public: 569 AssertOpLowering(MLIRContext *ctx, 570 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 571 : OpConversionPattern<cf::AssertOp>(ctx), 572 outlinedFunctions(outlinedFunctions) {} 573 574 LogicalResult 575 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 576 ConversionPatternRewriter &rewriter) const override { 577 // Check if assert operation is inside the async coroutine function. 578 auto func = op->template getParentOfType<FuncOp>(); 579 auto outlined = outlinedFunctions.find(func); 580 if (outlined == outlinedFunctions.end()) 581 return rewriter.notifyMatchFailure( 582 op, "operation is not inside the async coroutine function"); 583 584 Location loc = op->getLoc(); 585 CoroMachinery &coro = outlined->getSecond(); 586 587 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 588 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 589 rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), 590 /*trueDest=*/cont, 591 /*trueArgs=*/ArrayRef<Value>(), 592 /*falseDest=*/setupSetErrorBlock(coro), 593 /*falseArgs=*/ArrayRef<Value>()); 594 rewriter.eraseOp(op); 595 596 return success(); 597 } 598 599 private: 600 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 601 }; 602 603 //===----------------------------------------------------------------------===// 604 605 /// Rewrite a func as a coroutine by: 606 /// 1) Wrapping the results into `async.value`. 607 /// 2) Prepending the results with `async.token`. 608 /// 3) Setting up coroutine blocks. 609 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 610 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) { 611 auto *ctx = func->getContext(); 612 auto loc = func.getLoc(); 613 SmallVector<Type> resultTypes; 614 resultTypes.reserve(func.getCallableResults().size()); 615 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 616 [](Type type) { return ValueType::get(type); }); 617 func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes)); 618 func.insertResult(0, TokenType::get(ctx), {}); 619 for (Block &block : func.getBlocks()) { 620 Operation *terminator = block.getTerminator(); 621 if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) { 622 ImplicitLocOpBuilder builder(loc, returnOp); 623 builder.create<YieldOp>(returnOp.getOperands()); 624 returnOp.erase(); 625 } 626 } 627 return setupCoroMachinery(func); 628 } 629 630 /// Rewrites a call into a function that has been rewritten as a coroutine. 631 /// 632 /// The invocation of this function is safe only when call ops are traversed in 633 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 634 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) { 635 auto loc = func.getLoc(); 636 ImplicitLocOpBuilder callBuilder(loc, oldCall); 637 auto newCall = callBuilder.create<CallOp>( 638 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 639 640 // Await on the async token and all the value results and unwrap the latter. 641 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 642 SmallVector<Value> unwrappedResults; 643 unwrappedResults.reserve(newCall->getResults().size() - 1); 644 for (Value result : newCall.getResults().drop_front()) 645 unwrappedResults.push_back( 646 callBuilder.create<AwaitOp>(loc, result).result()); 647 // Careful, when result of a call is piped into another call this could lead 648 // to a dangling pointer. 649 oldCall.replaceAllUsesWith(unwrappedResults); 650 oldCall.erase(); 651 } 652 653 static bool isAllowedToBlock(FuncOp func) { 654 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 655 } 656 657 static LogicalResult 658 funcsToCoroutines(ModuleOp module, 659 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) { 660 // The following code supports the general case when 2 functions mutually 661 // recurse into each other. Because of this and that we are relying on 662 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 663 // a FuncOp while inserting an equivalent coroutine, because that could lead 664 // to dangling pointers. 665 666 SmallVector<FuncOp> funcWorklist; 667 668 // Careful, it's okay to add a func to the worklist multiple times if and only 669 // if the loop processing the worklist will skip the functions that have 670 // already been converted to coroutines. 671 auto addToWorklist = [&](FuncOp func) { 672 if (isAllowedToBlock(func)) 673 return; 674 // N.B. To refactor this code into a separate pass the lookup in 675 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 676 // func and recognizing if it has a coroutine structure is messy. Passing 677 // this dict between the passes is ugly. 678 if (isAllowedToBlock(func) || 679 outlinedFunctions.find(func) == outlinedFunctions.end()) { 680 for (Operation &op : func.body().getOps()) { 681 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 682 funcWorklist.push_back(func); 683 break; 684 } 685 } 686 } 687 }; 688 689 // Traverse in post-order collecting for each func op the await ops it has. 690 for (FuncOp func : module.getOps<FuncOp>()) 691 addToWorklist(func); 692 693 SymbolTableCollection symbolTable; 694 SymbolUserMap symbolUserMap(symbolTable, module); 695 696 // Rewrite funcs, while updating call sites and adding them to the worklist. 697 while (!funcWorklist.empty()) { 698 auto func = funcWorklist.pop_back_val(); 699 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 700 if (!insertion.second) 701 // This function has already been processed because this is either 702 // the corecursive case, or a caller with multiple calls to a newly 703 // created corouting. Either way, skip updating the call sites. 704 continue; 705 insertion.first->second = rewriteFuncAsCoroutine(func); 706 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 707 symbolUserMap.getUsers(func).end()); 708 // If there are multiple calls from the same block they need to be traversed 709 // in reverse order so that symbolUserMap references are not invalidated 710 // when updating the users of the call op which is earlier in the block. 711 llvm::sort(users, [](Operation *a, Operation *b) { 712 Block *blockA = a->getBlock(); 713 Block *blockB = b->getBlock(); 714 // Impose arbitrary order on blocks so that there is a well-defined order. 715 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 716 }); 717 // Rewrite the callsites to await on results of the newly created coroutine. 718 for (Operation *op : users) { 719 if (CallOp call = dyn_cast<mlir::CallOp>(*op)) { 720 FuncOp caller = call->getParentOfType<FuncOp>(); 721 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 722 addToWorklist(caller); 723 } else { 724 op->emitError("Unexpected reference to func referenced by symbol"); 725 return failure(); 726 } 727 } 728 } 729 return success(); 730 } 731 732 //===----------------------------------------------------------------------===// 733 void AsyncToAsyncRuntimePass::runOnOperation() { 734 ModuleOp module = getOperation(); 735 SymbolTable symbolTable(module); 736 737 // Outline all `async.execute` body regions into async functions (coroutines). 738 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 739 740 module.walk([&](ExecuteOp execute) { 741 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 742 }); 743 744 LLVM_DEBUG({ 745 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 746 << " functions built from async.execute operations\n"; 747 }); 748 749 // Returns true if operation is inside the coroutine. 750 auto isInCoroutine = [&](Operation *op) -> bool { 751 auto parentFunc = op->getParentOfType<FuncOp>(); 752 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 753 }; 754 755 if (eliminateBlockingAwaitOps && 756 failed(funcsToCoroutines(module, outlinedFunctions))) { 757 signalPassFailure(); 758 return; 759 } 760 761 // Lower async operations to async.runtime operations. 762 MLIRContext *ctx = module->getContext(); 763 RewritePatternSet asyncPatterns(ctx); 764 765 // Conversion to async runtime augments original CFG with the coroutine CFG, 766 // and we have to make sure that structured control flow operations with async 767 // operations in nested regions will be converted to branch-based control flow 768 // before we add the coroutine basic blocks. 769 populateSCFToControlFlowConversionPatterns(asyncPatterns); 770 771 // Async lowering does not use type converter because it must preserve all 772 // types for async.runtime operations. 773 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 774 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 775 AwaitAllOpLowering, YieldOpLowering>(ctx, 776 outlinedFunctions); 777 778 // Lower assertions to conditional branches into error blocks. 779 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 780 781 // All high level async operations must be lowered to the runtime operations. 782 ConversionTarget runtimeTarget(*ctx); 783 runtimeTarget.addLegalDialect<AsyncDialect>(); 784 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 785 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 786 787 // Decide if structured control flow has to be lowered to branch-based CFG. 788 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 789 auto walkResult = op->walk([&](Operation *nested) { 790 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 791 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 792 : WalkResult::advance(); 793 }); 794 return !walkResult.wasInterrupted(); 795 }); 796 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp, 797 ConstantOp, cf::BranchOp, cf::CondBranchOp>(); 798 799 // Assertions must be converted to runtime errors inside async functions. 800 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>( 801 [&](cf::AssertOp op) -> bool { 802 auto func = op->getParentOfType<FuncOp>(); 803 return outlinedFunctions.find(func) == outlinedFunctions.end(); 804 }); 805 806 if (eliminateBlockingAwaitOps) 807 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 808 [&](RuntimeAwaitOp op) -> bool { 809 return isAllowedToBlock(op->getParentOfType<FuncOp>()); 810 }); 811 812 if (failed(applyPartialConversion(module, runtimeTarget, 813 std::move(asyncPatterns)))) { 814 signalPassFailure(); 815 return; 816 } 817 } 818 819 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 820 return std::make_unique<AsyncToAsyncRuntimePass>(); 821 } 822