1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Async/IR/Async.h" 18 #include "mlir/Dialect/Async/Passes.h" 19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20 #include "mlir/Dialect/Func/IR/FuncOps.h" 21 #include "mlir/Dialect/SCF/SCF.h" 22 #include "mlir/IR/BlockAndValueMapping.h" 23 #include "mlir/IR/ImplicitLocOpBuilder.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "mlir/Transforms/RegionUtils.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/Support/Debug.h" 29 30 using namespace mlir; 31 using namespace mlir::async; 32 33 #define DEBUG_TYPE "async-to-async-runtime" 34 // Prefix for functions outlined from `async.execute` op regions. 35 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 36 37 namespace { 38 39 class AsyncToAsyncRuntimePass 40 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 41 public: 42 AsyncToAsyncRuntimePass() = default; 43 void runOnOperation() override; 44 }; 45 46 } // namespace 47 48 //===----------------------------------------------------------------------===// 49 // async.execute op outlining to the coroutine functions. 50 //===----------------------------------------------------------------------===// 51 52 /// Function targeted for coroutine transformation has two additional blocks at 53 /// the end: coroutine cleanup and coroutine suspension. 54 /// 55 /// async.await op lowering additionaly creates a resume block for each 56 /// operation to enable non-blocking waiting via coroutine suspension. 57 namespace { 58 struct CoroMachinery { 59 func::FuncOp func; 60 61 // Async execute region returns a completion token, and an async value for 62 // each yielded value. 63 // 64 // %token, %result = async.execute -> !async.value<T> { 65 // %0 = arith.constant ... : T 66 // async.yield %0 : T 67 // } 68 Value asyncToken; // token representing completion of the async region 69 llvm::SmallVector<Value, 4> returnValues; // returned async values 70 71 Value coroHandle; // coroutine handle (!async.coro.handle value) 72 Block *entry; // coroutine entry block 73 Block *setError; // switch completion token and all values to error state 74 Block *cleanup; // coroutine cleanup block 75 Block *suspend; // coroutine suspension block 76 }; 77 } // namespace 78 79 /// Utility to partially update the regular function CFG to the coroutine CFG 80 /// compatible with LLVM coroutines switched-resume lowering using 81 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 82 /// that branches into preexisting entry block. Also inserts trailing blocks. 83 /// 84 /// The result types of the passed `func` must start with an `async.token` 85 /// and be continued with some number of `async.value`s. 86 /// 87 /// The func given to this function needs to have been preprocessed to have 88 /// either branch or yield ops as terminators. Branches to the cleanup block are 89 /// inserted after each yield. 90 /// 91 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 92 /// 93 /// - `entry` block sets up the coroutine. 94 /// - `set_error` block sets completion token and async values state to error. 95 /// - `cleanup` block cleans up the coroutine state. 96 /// - `suspend block after the @llvm.coro.end() defines what value will be 97 /// returned to the initial caller of a coroutine. Everything before the 98 /// @llvm.coro.end() will be executed at every suspension point. 99 /// 100 /// Coroutine structure (only the important bits): 101 /// 102 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 103 /// { 104 /// ^entry(<function-arguments>): 105 /// %token = <async token> : !async.token // create async runtime token 106 /// %value = <async value> : !async.value<T> // create async value 107 /// %id = async.coro.id // create a coroutine id 108 /// %hdl = async.coro.begin %id // create a coroutine handle 109 /// cf.br ^preexisting_entry_block 110 /// 111 /// /* preexisting blocks modified to branch to the cleanup block */ 112 /// 113 /// ^set_error: // this block created lazily only if needed (see code below) 114 /// async.runtime.set_error %token : !async.token 115 /// async.runtime.set_error %value : !async.value<T> 116 /// cf.br ^cleanup 117 /// 118 /// ^cleanup: 119 /// async.coro.free %hdl // delete the coroutine state 120 /// cf.br ^suspend 121 /// 122 /// ^suspend: 123 /// async.coro.end %hdl // marks the end of a coroutine 124 /// return %token, %value : !async.token, !async.value<T> 125 /// } 126 /// 127 static CoroMachinery setupCoroMachinery(func::FuncOp func) { 128 assert(!func.getBlocks().empty() && "Function must have an entry block"); 129 130 MLIRContext *ctx = func.getContext(); 131 Block *entryBlock = &func.getBlocks().front(); 132 Block *originalEntryBlock = 133 entryBlock->splitBlock(entryBlock->getOperations().begin()); 134 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 135 136 // ------------------------------------------------------------------------ // 137 // Allocate async token/values that we will return from a ramp function. 138 // ------------------------------------------------------------------------ // 139 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 140 141 llvm::SmallVector<Value, 4> retValues; 142 for (auto resType : func.getCallableResults().drop_front()) 143 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 144 145 // ------------------------------------------------------------------------ // 146 // Initialize coroutine: get coroutine id and coroutine handle. 147 // ------------------------------------------------------------------------ // 148 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 149 auto coroHdlOp = 150 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 151 builder.create<cf::BranchOp>(originalEntryBlock); 152 153 Block *cleanupBlock = func.addBlock(); 154 Block *suspendBlock = func.addBlock(); 155 156 // ------------------------------------------------------------------------ // 157 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 158 // ------------------------------------------------------------------------ // 159 builder.setInsertionPointToStart(cleanupBlock); 160 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 161 162 // Branch into the suspend block. 163 builder.create<cf::BranchOp>(suspendBlock); 164 165 // ------------------------------------------------------------------------ // 166 // Coroutine suspend block: mark the end of a coroutine and return allocated 167 // async token. 168 // ------------------------------------------------------------------------ // 169 builder.setInsertionPointToStart(suspendBlock); 170 171 // Mark the end of a coroutine: async.coro.end 172 builder.create<CoroEndOp>(coroHdlOp.handle()); 173 174 // Return created `async.token` and `async.values` from the suspend block. 175 // This will be the return value of a coroutine ramp function. 176 SmallVector<Value, 4> ret{retToken}; 177 ret.insert(ret.end(), retValues.begin(), retValues.end()); 178 builder.create<func::ReturnOp>(ret); 179 180 // `async.await` op lowering will create resume blocks for async 181 // continuations, and will conditionally branch to cleanup or suspend blocks. 182 183 for (Block &block : func.getBody().getBlocks()) { 184 if (&block == entryBlock || &block == cleanupBlock || 185 &block == suspendBlock) 186 continue; 187 Operation *terminator = block.getTerminator(); 188 if (auto yield = dyn_cast<YieldOp>(terminator)) { 189 builder.setInsertionPointToEnd(&block); 190 builder.create<cf::BranchOp>(cleanupBlock); 191 } 192 } 193 194 // The switch-resumed API based coroutine should be marked with 195 // coroutine.presplit attribute to mark the function as a coroutine. 196 func->setAttr("passthrough", builder.getArrayAttr( 197 StringAttr::get(ctx, "presplitcoroutine"))); 198 199 CoroMachinery machinery; 200 machinery.func = func; 201 machinery.asyncToken = retToken; 202 machinery.returnValues = retValues; 203 machinery.coroHandle = coroHdlOp.handle(); 204 machinery.entry = entryBlock; 205 machinery.setError = nullptr; // created lazily only if needed 206 machinery.cleanup = cleanupBlock; 207 machinery.suspend = suspendBlock; 208 return machinery; 209 } 210 211 // Lazily creates `set_error` block only if it is required for lowering to the 212 // runtime operations (see for example lowering of assert operation). 213 static Block *setupSetErrorBlock(CoroMachinery &coro) { 214 if (coro.setError) 215 return coro.setError; 216 217 coro.setError = coro.func.addBlock(); 218 coro.setError->moveBefore(coro.cleanup); 219 220 auto builder = 221 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 222 223 // Coroutine set_error block: set error on token and all returned values. 224 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 225 for (Value retValue : coro.returnValues) 226 builder.create<RuntimeSetErrorOp>(retValue); 227 228 // Branch into the cleanup block. 229 builder.create<cf::BranchOp>(coro.cleanup); 230 231 return coro.setError; 232 } 233 234 /// Outline the body region attached to the `async.execute` op into a standalone 235 /// function. 236 /// 237 /// Note that this is not reversible transformation. 238 static std::pair<func::FuncOp, CoroMachinery> 239 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 240 ModuleOp module = execute->getParentOfType<ModuleOp>(); 241 242 MLIRContext *ctx = module.getContext(); 243 Location loc = execute.getLoc(); 244 245 // Make sure that all constants will be inside the outlined async function to 246 // reduce the number of function arguments. 247 cloneConstantsIntoTheRegion(execute.body()); 248 249 // Collect all outlined function inputs. 250 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 251 execute.dependencies().end()); 252 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 253 getUsedValuesDefinedAbove(execute.body(), functionInputs); 254 255 // Collect types for the outlined function inputs and outputs. 256 auto typesRange = llvm::map_range( 257 functionInputs, [](Value value) { return value.getType(); }); 258 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 259 auto outputTypes = execute.getResultTypes(); 260 261 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 262 auto funcAttrs = ArrayRef<NamedAttribute>(); 263 264 // TODO: Derive outlined function name from the parent FuncOp (support 265 // multiple nested async.execute operations). 266 func::FuncOp func = 267 func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 268 symbolTable.insert(func); 269 270 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 271 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 272 273 // Prepare for coroutine conversion by creating the body of the function. 274 { 275 size_t numDependencies = execute.dependencies().size(); 276 size_t numOperands = execute.operands().size(); 277 278 // Await on all dependencies before starting to execute the body region. 279 for (size_t i = 0; i < numDependencies; ++i) 280 builder.create<AwaitOp>(func.getArgument(i)); 281 282 // Await on all async value operands and unwrap the payload. 283 SmallVector<Value, 4> unwrappedOperands(numOperands); 284 for (size_t i = 0; i < numOperands; ++i) { 285 Value operand = func.getArgument(numDependencies + i); 286 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 287 } 288 289 // Map from function inputs defined above the execute op to the function 290 // arguments. 291 BlockAndValueMapping valueMapping; 292 valueMapping.map(functionInputs, func.getArguments()); 293 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 294 295 // Clone all operations from the execute operation body into the outlined 296 // function body. 297 for (Operation &op : execute.body().getOps()) 298 builder.clone(op, valueMapping); 299 } 300 301 // Adding entry/cleanup/suspend blocks. 302 CoroMachinery coro = setupCoroMachinery(func); 303 304 // Suspend async function at the end of an entry block, and resume it using 305 // Async resume operation (execution will be resumed in a thread managed by 306 // the async runtime). 307 { 308 cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator()); 309 builder.setInsertionPointToEnd(coro.entry); 310 311 // Save the coroutine state: async.coro.save 312 auto coroSaveOp = 313 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 314 315 // Pass coroutine to the runtime to be resumed on a runtime managed 316 // thread. 317 builder.create<RuntimeResumeOp>(coro.coroHandle); 318 319 // Add async.coro.suspend as a suspended block terminator. 320 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 321 branch.getDest(), coro.cleanup); 322 323 branch.erase(); 324 } 325 326 // Replace the original `async.execute` with a call to outlined function. 327 { 328 ImplicitLocOpBuilder callBuilder(loc, execute); 329 auto callOutlinedFunc = callBuilder.create<func::CallOp>( 330 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 331 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 332 execute.erase(); 333 } 334 335 return {func, coro}; 336 } 337 338 //===----------------------------------------------------------------------===// 339 // Convert async.create_group operation to async.runtime.create_group 340 //===----------------------------------------------------------------------===// 341 342 namespace { 343 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 344 public: 345 using OpConversionPattern::OpConversionPattern; 346 347 LogicalResult 348 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, 349 ConversionPatternRewriter &rewriter) const override { 350 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 351 op, GroupType::get(op->getContext()), adaptor.getOperands()); 352 return success(); 353 } 354 }; 355 } // namespace 356 357 //===----------------------------------------------------------------------===// 358 // Convert async.add_to_group operation to async.runtime.add_to_group. 359 //===----------------------------------------------------------------------===// 360 361 namespace { 362 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 363 public: 364 using OpConversionPattern::OpConversionPattern; 365 366 LogicalResult 367 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, 368 ConversionPatternRewriter &rewriter) const override { 369 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 370 op, rewriter.getIndexType(), adaptor.getOperands()); 371 return success(); 372 } 373 }; 374 } // namespace 375 376 //===----------------------------------------------------------------------===// 377 // Convert async.await and async.await_all operations to the async.runtime.await 378 // or async.runtime.await_and_resume operations. 379 //===----------------------------------------------------------------------===// 380 381 namespace { 382 template <typename AwaitType, typename AwaitableType> 383 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 384 using AwaitAdaptor = typename AwaitType::Adaptor; 385 386 public: 387 AwaitOpLoweringBase( 388 MLIRContext *ctx, 389 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) 390 : OpConversionPattern<AwaitType>(ctx), 391 outlinedFunctions(outlinedFunctions) {} 392 393 LogicalResult 394 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, 395 ConversionPatternRewriter &rewriter) const override { 396 // We can only await on one the `AwaitableType` (for `await` it can be 397 // a `token` or a `value`, for `await_all` it must be a `group`). 398 if (!op.operand().getType().template isa<AwaitableType>()) 399 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 400 401 // Check if await operation is inside the outlined coroutine function. 402 auto func = op->template getParentOfType<func::FuncOp>(); 403 auto outlined = outlinedFunctions.find(func); 404 const bool isInCoroutine = outlined != outlinedFunctions.end(); 405 406 Location loc = op->getLoc(); 407 Value operand = adaptor.operand(); 408 409 Type i1 = rewriter.getI1Type(); 410 411 // Inside regular functions we use the blocking wait operation to wait for 412 // the async object (token, value or group) to become available. 413 if (!isInCoroutine) { 414 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 415 builder.create<RuntimeAwaitOp>(loc, operand); 416 417 // Assert that the awaited operands is not in the error state. 418 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand); 419 Value notError = builder.create<arith::XOrIOp>( 420 isError, builder.create<arith::ConstantOp>( 421 loc, i1, builder.getIntegerAttr(i1, 1))); 422 423 builder.create<cf::AssertOp>(notError, 424 "Awaited async operand is in error state"); 425 } 426 427 // Inside the coroutine we convert await operation into coroutine suspension 428 // point, and resume execution asynchronously. 429 if (isInCoroutine) { 430 CoroMachinery &coro = outlined->getSecond(); 431 Block *suspended = op->getBlock(); 432 433 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 434 MLIRContext *ctx = op->getContext(); 435 436 // Save the coroutine state and resume on a runtime managed thread when 437 // the operand becomes available. 438 auto coroSaveOp = 439 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 440 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 441 442 // Split the entry block before the await operation. 443 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 444 445 // Add async.coro.suspend as a suspended block terminator. 446 builder.setInsertionPointToEnd(suspended); 447 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 448 coro.cleanup); 449 450 // Split the resume block into error checking and continuation. 451 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 452 453 // Check if the awaited value is in the error state. 454 builder.setInsertionPointToStart(resume); 455 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); 456 builder.create<cf::CondBranchOp>(isError, 457 /*trueDest=*/setupSetErrorBlock(coro), 458 /*trueArgs=*/ArrayRef<Value>(), 459 /*falseDest=*/continuation, 460 /*falseArgs=*/ArrayRef<Value>()); 461 462 // Make sure that replacement value will be constructed in the 463 // continuation block. 464 rewriter.setInsertionPointToStart(continuation); 465 } 466 467 // Erase or replace the await operation with the new value. 468 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 469 rewriter.replaceOp(op, replaceWith); 470 else 471 rewriter.eraseOp(op); 472 473 return success(); 474 } 475 476 virtual Value getReplacementValue(AwaitType op, Value operand, 477 ConversionPatternRewriter &rewriter) const { 478 return Value(); 479 } 480 481 private: 482 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions; 483 }; 484 485 /// Lowering for `async.await` with a token operand. 486 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 487 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 488 489 public: 490 using Base::Base; 491 }; 492 493 /// Lowering for `async.await` with a value operand. 494 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 495 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 496 497 public: 498 using Base::Base; 499 500 Value 501 getReplacementValue(AwaitOp op, Value operand, 502 ConversionPatternRewriter &rewriter) const override { 503 // Load from the async value storage. 504 auto valueType = operand.getType().cast<ValueType>().getValueType(); 505 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 506 } 507 }; 508 509 /// Lowering for `async.await_all` operation. 510 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 511 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 512 513 public: 514 using Base::Base; 515 }; 516 517 } // namespace 518 519 //===----------------------------------------------------------------------===// 520 // Convert async.yield operation to async.runtime operations. 521 //===----------------------------------------------------------------------===// 522 523 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 524 public: 525 YieldOpLowering( 526 MLIRContext *ctx, 527 const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) 528 : OpConversionPattern<async::YieldOp>(ctx), 529 outlinedFunctions(outlinedFunctions) {} 530 531 LogicalResult 532 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 533 ConversionPatternRewriter &rewriter) const override { 534 // Check if yield operation is inside the async coroutine function. 535 auto func = op->template getParentOfType<func::FuncOp>(); 536 auto outlined = outlinedFunctions.find(func); 537 if (outlined == outlinedFunctions.end()) 538 return rewriter.notifyMatchFailure( 539 op, "operation is not inside the async coroutine function"); 540 541 Location loc = op->getLoc(); 542 const CoroMachinery &coro = outlined->getSecond(); 543 544 // Store yielded values into the async values storage and switch async 545 // values state to available. 546 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { 547 Value yieldValue = std::get<0>(tuple); 548 Value asyncValue = std::get<1>(tuple); 549 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 550 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 551 } 552 553 // Switch the coroutine completion token to available state. 554 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 555 556 return success(); 557 } 558 559 private: 560 const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions; 561 }; 562 563 //===----------------------------------------------------------------------===// 564 // Convert cf.assert operation to cf.cond_br into `set_error` block. 565 //===----------------------------------------------------------------------===// 566 567 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> { 568 public: 569 AssertOpLowering( 570 MLIRContext *ctx, 571 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) 572 : OpConversionPattern<cf::AssertOp>(ctx), 573 outlinedFunctions(outlinedFunctions) {} 574 575 LogicalResult 576 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 577 ConversionPatternRewriter &rewriter) const override { 578 // Check if assert operation is inside the async coroutine function. 579 auto func = op->template getParentOfType<func::FuncOp>(); 580 auto outlined = outlinedFunctions.find(func); 581 if (outlined == outlinedFunctions.end()) 582 return rewriter.notifyMatchFailure( 583 op, "operation is not inside the async coroutine function"); 584 585 Location loc = op->getLoc(); 586 CoroMachinery &coro = outlined->getSecond(); 587 588 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 589 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 590 rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), 591 /*trueDest=*/cont, 592 /*trueArgs=*/ArrayRef<Value>(), 593 /*falseDest=*/setupSetErrorBlock(coro), 594 /*falseArgs=*/ArrayRef<Value>()); 595 rewriter.eraseOp(op); 596 597 return success(); 598 } 599 600 private: 601 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions; 602 }; 603 604 //===----------------------------------------------------------------------===// 605 606 /// Rewrite a func as a coroutine by: 607 /// 1) Wrapping the results into `async.value`. 608 /// 2) Prepending the results with `async.token`. 609 /// 3) Setting up coroutine blocks. 610 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 611 static CoroMachinery rewriteFuncAsCoroutine(func::FuncOp func) { 612 auto *ctx = func->getContext(); 613 auto loc = func.getLoc(); 614 SmallVector<Type> resultTypes; 615 resultTypes.reserve(func.getCallableResults().size()); 616 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 617 [](Type type) { return ValueType::get(type); }); 618 func.setType( 619 FunctionType::get(ctx, func.getFunctionType().getInputs(), resultTypes)); 620 func.insertResult(0, TokenType::get(ctx), {}); 621 for (Block &block : func.getBlocks()) { 622 Operation *terminator = block.getTerminator(); 623 if (auto returnOp = dyn_cast<func::ReturnOp>(*terminator)) { 624 ImplicitLocOpBuilder builder(loc, returnOp); 625 builder.create<YieldOp>(returnOp.getOperands()); 626 returnOp.erase(); 627 } 628 } 629 return setupCoroMachinery(func); 630 } 631 632 /// Rewrites a call into a function that has been rewritten as a coroutine. 633 /// 634 /// The invocation of this function is safe only when call ops are traversed in 635 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 636 static void rewriteCallsiteForCoroutine(func::CallOp oldCall, 637 func::FuncOp func) { 638 auto loc = func.getLoc(); 639 ImplicitLocOpBuilder callBuilder(loc, oldCall); 640 auto newCall = callBuilder.create<func::CallOp>( 641 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 642 643 // Await on the async token and all the value results and unwrap the latter. 644 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 645 SmallVector<Value> unwrappedResults; 646 unwrappedResults.reserve(newCall->getResults().size() - 1); 647 for (Value result : newCall.getResults().drop_front()) 648 unwrappedResults.push_back( 649 callBuilder.create<AwaitOp>(loc, result).result()); 650 // Careful, when result of a call is piped into another call this could lead 651 // to a dangling pointer. 652 oldCall.replaceAllUsesWith(unwrappedResults); 653 oldCall.erase(); 654 } 655 656 static bool isAllowedToBlock(func::FuncOp func) { 657 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 658 } 659 660 static LogicalResult funcsToCoroutines( 661 ModuleOp module, 662 llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) { 663 // The following code supports the general case when 2 functions mutually 664 // recurse into each other. Because of this and that we are relying on 665 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 666 // a FuncOp while inserting an equivalent coroutine, because that could lead 667 // to dangling pointers. 668 669 SmallVector<func::FuncOp> funcWorklist; 670 671 // Careful, it's okay to add a func to the worklist multiple times if and only 672 // if the loop processing the worklist will skip the functions that have 673 // already been converted to coroutines. 674 auto addToWorklist = [&](func::FuncOp func) { 675 if (isAllowedToBlock(func)) 676 return; 677 // N.B. To refactor this code into a separate pass the lookup in 678 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 679 // func and recognizing if it has a coroutine structure is messy. Passing 680 // this dict between the passes is ugly. 681 if (isAllowedToBlock(func) || 682 outlinedFunctions.find(func) == outlinedFunctions.end()) { 683 for (Operation &op : func.getBody().getOps()) { 684 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 685 funcWorklist.push_back(func); 686 break; 687 } 688 } 689 } 690 }; 691 692 // Traverse in post-order collecting for each func op the await ops it has. 693 for (func::FuncOp func : module.getOps<func::FuncOp>()) 694 addToWorklist(func); 695 696 SymbolTableCollection symbolTable; 697 SymbolUserMap symbolUserMap(symbolTable, module); 698 699 // Rewrite funcs, while updating call sites and adding them to the worklist. 700 while (!funcWorklist.empty()) { 701 auto func = funcWorklist.pop_back_val(); 702 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 703 if (!insertion.second) 704 // This function has already been processed because this is either 705 // the corecursive case, or a caller with multiple calls to a newly 706 // created corouting. Either way, skip updating the call sites. 707 continue; 708 insertion.first->second = rewriteFuncAsCoroutine(func); 709 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 710 symbolUserMap.getUsers(func).end()); 711 // If there are multiple calls from the same block they need to be traversed 712 // in reverse order so that symbolUserMap references are not invalidated 713 // when updating the users of the call op which is earlier in the block. 714 llvm::sort(users, [](Operation *a, Operation *b) { 715 Block *blockA = a->getBlock(); 716 Block *blockB = b->getBlock(); 717 // Impose arbitrary order on blocks so that there is a well-defined order. 718 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 719 }); 720 // Rewrite the callsites to await on results of the newly created coroutine. 721 for (Operation *op : users) { 722 if (func::CallOp call = dyn_cast<func::CallOp>(*op)) { 723 func::FuncOp caller = call->getParentOfType<func::FuncOp>(); 724 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 725 addToWorklist(caller); 726 } else { 727 op->emitError("Unexpected reference to func referenced by symbol"); 728 return failure(); 729 } 730 } 731 } 732 return success(); 733 } 734 735 //===----------------------------------------------------------------------===// 736 void AsyncToAsyncRuntimePass::runOnOperation() { 737 ModuleOp module = getOperation(); 738 SymbolTable symbolTable(module); 739 740 // Outline all `async.execute` body regions into async functions (coroutines). 741 llvm::DenseMap<func::FuncOp, CoroMachinery> outlinedFunctions; 742 743 module.walk([&](ExecuteOp execute) { 744 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 745 }); 746 747 LLVM_DEBUG({ 748 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 749 << " functions built from async.execute operations\n"; 750 }); 751 752 // Returns true if operation is inside the coroutine. 753 auto isInCoroutine = [&](Operation *op) -> bool { 754 auto parentFunc = op->getParentOfType<func::FuncOp>(); 755 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 756 }; 757 758 if (eliminateBlockingAwaitOps && 759 failed(funcsToCoroutines(module, outlinedFunctions))) { 760 signalPassFailure(); 761 return; 762 } 763 764 // Lower async operations to async.runtime operations. 765 MLIRContext *ctx = module->getContext(); 766 RewritePatternSet asyncPatterns(ctx); 767 768 // Conversion to async runtime augments original CFG with the coroutine CFG, 769 // and we have to make sure that structured control flow operations with async 770 // operations in nested regions will be converted to branch-based control flow 771 // before we add the coroutine basic blocks. 772 populateSCFToControlFlowConversionPatterns(asyncPatterns); 773 774 // Async lowering does not use type converter because it must preserve all 775 // types for async.runtime operations. 776 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 777 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 778 AwaitAllOpLowering, YieldOpLowering>(ctx, 779 outlinedFunctions); 780 781 // Lower assertions to conditional branches into error blocks. 782 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 783 784 // All high level async operations must be lowered to the runtime operations. 785 ConversionTarget runtimeTarget(*ctx); 786 runtimeTarget.addLegalDialect<AsyncDialect>(); 787 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 788 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 789 790 // Decide if structured control flow has to be lowered to branch-based CFG. 791 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 792 auto walkResult = op->walk([&](Operation *nested) { 793 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 794 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 795 : WalkResult::advance(); 796 }); 797 return !walkResult.wasInterrupted(); 798 }); 799 runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp, 800 func::ConstantOp, cf::BranchOp, cf::CondBranchOp>(); 801 802 // Assertions must be converted to runtime errors inside async functions. 803 runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>( 804 [&](cf::AssertOp op) -> bool { 805 auto func = op->getParentOfType<func::FuncOp>(); 806 return outlinedFunctions.find(func) == outlinedFunctions.end(); 807 }); 808 809 if (eliminateBlockingAwaitOps) 810 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 811 [&](RuntimeAwaitOp op) -> bool { 812 return isAllowedToBlock(op->getParentOfType<func::FuncOp>()); 813 }); 814 815 if (failed(applyPartialConversion(module, runtimeTarget, 816 std::move(asyncPatterns)))) { 817 signalPassFailure(); 818 return; 819 } 820 } 821 822 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 823 return std::make_unique<AsyncToAsyncRuntimePass>(); 824 } 825