1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Async/IR/Async.h" 18 #include "mlir/Dialect/Async/Passes.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 #include "mlir/IR/BlockAndValueMapping.h" 22 #include "mlir/IR/ImplicitLocOpBuilder.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "mlir/Transforms/RegionUtils.h" 26 #include "llvm/ADT/SetVector.h" 27 #include "llvm/Support/Debug.h" 28 29 using namespace mlir; 30 using namespace mlir::async; 31 32 #define DEBUG_TYPE "async-to-async-runtime" 33 // Prefix for functions outlined from `async.execute` op regions. 34 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 35 36 namespace { 37 38 class AsyncToAsyncRuntimePass 39 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 40 public: 41 AsyncToAsyncRuntimePass() = default; 42 void runOnOperation() override; 43 }; 44 45 } // namespace 46 47 //===----------------------------------------------------------------------===// 48 // async.execute op outlining to the coroutine functions. 49 //===----------------------------------------------------------------------===// 50 51 /// Function targeted for coroutine transformation has two additional blocks at 52 /// the end: coroutine cleanup and coroutine suspension. 53 /// 54 /// async.await op lowering additionaly creates a resume block for each 55 /// operation to enable non-blocking waiting via coroutine suspension. 56 namespace { 57 struct CoroMachinery { 58 FuncOp func; 59 60 // Async execute region returns a completion token, and an async value for 61 // each yielded value. 62 // 63 // %token, %result = async.execute -> !async.value<T> { 64 // %0 = arith.constant ... : T 65 // async.yield %0 : T 66 // } 67 Value asyncToken; // token representing completion of the async region 68 llvm::SmallVector<Value, 4> returnValues; // returned async values 69 70 Value coroHandle; // coroutine handle (!async.coro.handle value) 71 Block *entry; // coroutine entry block 72 Block *setError; // switch completion token and all values to error state 73 Block *cleanup; // coroutine cleanup block 74 Block *suspend; // coroutine suspension block 75 }; 76 } // namespace 77 78 /// Utility to partially update the regular function CFG to the coroutine CFG 79 /// compatible with LLVM coroutines switched-resume lowering using 80 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 81 /// that branches into preexisting entry block. Also inserts trailing blocks. 82 /// 83 /// The result types of the passed `func` must start with an `async.token` 84 /// and be continued with some number of `async.value`s. 85 /// 86 /// The func given to this function needs to have been preprocessed to have 87 /// either branch or yield ops as terminators. Branches to the cleanup block are 88 /// inserted after each yield. 89 /// 90 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 91 /// 92 /// - `entry` block sets up the coroutine. 93 /// - `set_error` block sets completion token and async values state to error. 94 /// - `cleanup` block cleans up the coroutine state. 95 /// - `suspend block after the @llvm.coro.end() defines what value will be 96 /// returned to the initial caller of a coroutine. Everything before the 97 /// @llvm.coro.end() will be executed at every suspension point. 98 /// 99 /// Coroutine structure (only the important bits): 100 /// 101 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 102 /// { 103 /// ^entry(<function-arguments>): 104 /// %token = <async token> : !async.token // create async runtime token 105 /// %value = <async value> : !async.value<T> // create async value 106 /// %id = async.coro.id // create a coroutine id 107 /// %hdl = async.coro.begin %id // create a coroutine handle 108 /// br ^preexisting_entry_block 109 /// 110 /// /* preexisting blocks modified to branch to the cleanup block */ 111 /// 112 /// ^set_error: // this block created lazily only if needed (see code below) 113 /// async.runtime.set_error %token : !async.token 114 /// async.runtime.set_error %value : !async.value<T> 115 /// br ^cleanup 116 /// 117 /// ^cleanup: 118 /// async.coro.free %hdl // delete the coroutine state 119 /// br ^suspend 120 /// 121 /// ^suspend: 122 /// async.coro.end %hdl // marks the end of a coroutine 123 /// return %token, %value : !async.token, !async.value<T> 124 /// } 125 /// 126 static CoroMachinery setupCoroMachinery(FuncOp func) { 127 assert(!func.getBlocks().empty() && "Function must have an entry block"); 128 129 MLIRContext *ctx = func.getContext(); 130 Block *entryBlock = &func.getBlocks().front(); 131 Block *originalEntryBlock = 132 entryBlock->splitBlock(entryBlock->getOperations().begin()); 133 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 134 135 // ------------------------------------------------------------------------ // 136 // Allocate async token/values that we will return from a ramp function. 137 // ------------------------------------------------------------------------ // 138 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 139 140 llvm::SmallVector<Value, 4> retValues; 141 for (auto resType : func.getCallableResults().drop_front()) 142 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 143 144 // ------------------------------------------------------------------------ // 145 // Initialize coroutine: get coroutine id and coroutine handle. 146 // ------------------------------------------------------------------------ // 147 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 148 auto coroHdlOp = 149 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 150 builder.create<BranchOp>(originalEntryBlock); 151 152 Block *cleanupBlock = func.addBlock(); 153 Block *suspendBlock = func.addBlock(); 154 155 // ------------------------------------------------------------------------ // 156 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 157 // ------------------------------------------------------------------------ // 158 builder.setInsertionPointToStart(cleanupBlock); 159 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 160 161 // Branch into the suspend block. 162 builder.create<BranchOp>(suspendBlock); 163 164 // ------------------------------------------------------------------------ // 165 // Coroutine suspend block: mark the end of a coroutine and return allocated 166 // async token. 167 // ------------------------------------------------------------------------ // 168 builder.setInsertionPointToStart(suspendBlock); 169 170 // Mark the end of a coroutine: async.coro.end 171 builder.create<CoroEndOp>(coroHdlOp.handle()); 172 173 // Return created `async.token` and `async.values` from the suspend block. 174 // This will be the return value of a coroutine ramp function. 175 SmallVector<Value, 4> ret{retToken}; 176 ret.insert(ret.end(), retValues.begin(), retValues.end()); 177 builder.create<ReturnOp>(ret); 178 179 // `async.await` op lowering will create resume blocks for async 180 // continuations, and will conditionally branch to cleanup or suspend blocks. 181 182 for (Block &block : func.body().getBlocks()) { 183 if (&block == entryBlock || &block == cleanupBlock || 184 &block == suspendBlock) 185 continue; 186 Operation *terminator = block.getTerminator(); 187 if (auto yield = dyn_cast<YieldOp>(terminator)) { 188 builder.setInsertionPointToEnd(&block); 189 builder.create<BranchOp>(cleanupBlock); 190 } 191 } 192 193 // The switch-resumed API based coroutine should be marked with 194 // "coroutine.presplit" attribute with value "0" to mark the function as a 195 // coroutine. 196 func->setAttr("passthrough", builder.getArrayAttr(builder.getArrayAttr( 197 {builder.getStringAttr("coroutine.presplit"), 198 builder.getStringAttr("0")}))); 199 200 CoroMachinery machinery; 201 machinery.func = func; 202 machinery.asyncToken = retToken; 203 machinery.returnValues = retValues; 204 machinery.coroHandle = coroHdlOp.handle(); 205 machinery.entry = entryBlock; 206 machinery.setError = nullptr; // created lazily only if needed 207 machinery.cleanup = cleanupBlock; 208 machinery.suspend = suspendBlock; 209 return machinery; 210 } 211 212 // Lazily creates `set_error` block only if it is required for lowering to the 213 // runtime operations (see for example lowering of assert operation). 214 static Block *setupSetErrorBlock(CoroMachinery &coro) { 215 if (coro.setError) 216 return coro.setError; 217 218 coro.setError = coro.func.addBlock(); 219 coro.setError->moveBefore(coro.cleanup); 220 221 auto builder = 222 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 223 224 // Coroutine set_error block: set error on token and all returned values. 225 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 226 for (Value retValue : coro.returnValues) 227 builder.create<RuntimeSetErrorOp>(retValue); 228 229 // Branch into the cleanup block. 230 builder.create<BranchOp>(coro.cleanup); 231 232 return coro.setError; 233 } 234 235 /// Outline the body region attached to the `async.execute` op into a standalone 236 /// function. 237 /// 238 /// Note that this is not reversible transformation. 239 static std::pair<FuncOp, CoroMachinery> 240 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 241 ModuleOp module = execute->getParentOfType<ModuleOp>(); 242 243 MLIRContext *ctx = module.getContext(); 244 Location loc = execute.getLoc(); 245 246 // Make sure that all constants will be inside the outlined async function to 247 // reduce the number of function arguments. 248 cloneConstantsIntoTheRegion(execute.body()); 249 250 // Collect all outlined function inputs. 251 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 252 execute.dependencies().end()); 253 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 254 getUsedValuesDefinedAbove(execute.body(), functionInputs); 255 256 // Collect types for the outlined function inputs and outputs. 257 auto typesRange = llvm::map_range( 258 functionInputs, [](Value value) { return value.getType(); }); 259 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 260 auto outputTypes = execute.getResultTypes(); 261 262 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 263 auto funcAttrs = ArrayRef<NamedAttribute>(); 264 265 // TODO: Derive outlined function name from the parent FuncOp (support 266 // multiple nested async.execute operations). 267 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 268 symbolTable.insert(func); 269 270 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 271 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 272 273 // Prepare for coroutine conversion by creating the body of the function. 274 { 275 size_t numDependencies = execute.dependencies().size(); 276 size_t numOperands = execute.operands().size(); 277 278 // Await on all dependencies before starting to execute the body region. 279 for (size_t i = 0; i < numDependencies; ++i) 280 builder.create<AwaitOp>(func.getArgument(i)); 281 282 // Await on all async value operands and unwrap the payload. 283 SmallVector<Value, 4> unwrappedOperands(numOperands); 284 for (size_t i = 0; i < numOperands; ++i) { 285 Value operand = func.getArgument(numDependencies + i); 286 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 287 } 288 289 // Map from function inputs defined above the execute op to the function 290 // arguments. 291 BlockAndValueMapping valueMapping; 292 valueMapping.map(functionInputs, func.getArguments()); 293 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 294 295 // Clone all operations from the execute operation body into the outlined 296 // function body. 297 for (Operation &op : execute.body().getOps()) 298 builder.clone(op, valueMapping); 299 } 300 301 // Adding entry/cleanup/suspend blocks. 302 CoroMachinery coro = setupCoroMachinery(func); 303 304 // Suspend async function at the end of an entry block, and resume it using 305 // Async resume operation (execution will be resumed in a thread managed by 306 // the async runtime). 307 { 308 BranchOp branch = cast<BranchOp>(coro.entry->getTerminator()); 309 builder.setInsertionPointToEnd(coro.entry); 310 311 // Save the coroutine state: async.coro.save 312 auto coroSaveOp = 313 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 314 315 // Pass coroutine to the runtime to be resumed on a runtime managed 316 // thread. 317 builder.create<RuntimeResumeOp>(coro.coroHandle); 318 319 // Add async.coro.suspend as a suspended block terminator. 320 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 321 branch.getDest(), coro.cleanup); 322 323 branch.erase(); 324 } 325 326 // Replace the original `async.execute` with a call to outlined function. 327 { 328 ImplicitLocOpBuilder callBuilder(loc, execute); 329 auto callOutlinedFunc = callBuilder.create<CallOp>( 330 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 331 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 332 execute.erase(); 333 } 334 335 return {func, coro}; 336 } 337 338 //===----------------------------------------------------------------------===// 339 // Convert async.create_group operation to async.runtime.create_group 340 //===----------------------------------------------------------------------===// 341 342 namespace { 343 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 344 public: 345 using OpConversionPattern::OpConversionPattern; 346 347 LogicalResult 348 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, 349 ConversionPatternRewriter &rewriter) const override { 350 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 351 op, GroupType::get(op->getContext()), adaptor.getOperands()); 352 return success(); 353 } 354 }; 355 } // namespace 356 357 //===----------------------------------------------------------------------===// 358 // Convert async.add_to_group operation to async.runtime.add_to_group. 359 //===----------------------------------------------------------------------===// 360 361 namespace { 362 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 363 public: 364 using OpConversionPattern::OpConversionPattern; 365 366 LogicalResult 367 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, 368 ConversionPatternRewriter &rewriter) const override { 369 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 370 op, rewriter.getIndexType(), adaptor.getOperands()); 371 return success(); 372 } 373 }; 374 } // namespace 375 376 //===----------------------------------------------------------------------===// 377 // Convert async.await and async.await_all operations to the async.runtime.await 378 // or async.runtime.await_and_resume operations. 379 //===----------------------------------------------------------------------===// 380 381 namespace { 382 template <typename AwaitType, typename AwaitableType> 383 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 384 using AwaitAdaptor = typename AwaitType::Adaptor; 385 386 public: 387 AwaitOpLoweringBase(MLIRContext *ctx, 388 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 389 : OpConversionPattern<AwaitType>(ctx), 390 outlinedFunctions(outlinedFunctions) {} 391 392 LogicalResult 393 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, 394 ConversionPatternRewriter &rewriter) const override { 395 // We can only await on one the `AwaitableType` (for `await` it can be 396 // a `token` or a `value`, for `await_all` it must be a `group`). 397 if (!op.operand().getType().template isa<AwaitableType>()) 398 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 399 400 // Check if await operation is inside the outlined coroutine function. 401 auto func = op->template getParentOfType<FuncOp>(); 402 auto outlined = outlinedFunctions.find(func); 403 const bool isInCoroutine = outlined != outlinedFunctions.end(); 404 405 Location loc = op->getLoc(); 406 Value operand = adaptor.operand(); 407 408 Type i1 = rewriter.getI1Type(); 409 410 // Inside regular functions we use the blocking wait operation to wait for 411 // the async object (token, value or group) to become available. 412 if (!isInCoroutine) { 413 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 414 builder.create<RuntimeAwaitOp>(loc, operand); 415 416 // Assert that the awaited operands is not in the error state. 417 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand); 418 Value notError = builder.create<arith::XOrIOp>( 419 isError, builder.create<arith::ConstantOp>( 420 loc, i1, builder.getIntegerAttr(i1, 1))); 421 422 builder.create<AssertOp>(notError, 423 "Awaited async operand is in error state"); 424 } 425 426 // Inside the coroutine we convert await operation into coroutine suspension 427 // point, and resume execution asynchronously. 428 if (isInCoroutine) { 429 CoroMachinery &coro = outlined->getSecond(); 430 Block *suspended = op->getBlock(); 431 432 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 433 MLIRContext *ctx = op->getContext(); 434 435 // Save the coroutine state and resume on a runtime managed thread when 436 // the operand becomes available. 437 auto coroSaveOp = 438 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 439 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 440 441 // Split the entry block before the await operation. 442 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 443 444 // Add async.coro.suspend as a suspended block terminator. 445 builder.setInsertionPointToEnd(suspended); 446 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 447 coro.cleanup); 448 449 // Split the resume block into error checking and continuation. 450 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 451 452 // Check if the awaited value is in the error state. 453 builder.setInsertionPointToStart(resume); 454 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); 455 builder.create<CondBranchOp>(isError, 456 /*trueDest=*/setupSetErrorBlock(coro), 457 /*trueArgs=*/ArrayRef<Value>(), 458 /*falseDest=*/continuation, 459 /*falseArgs=*/ArrayRef<Value>()); 460 461 // Make sure that replacement value will be constructed in the 462 // continuation block. 463 rewriter.setInsertionPointToStart(continuation); 464 } 465 466 // Erase or replace the await operation with the new value. 467 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 468 rewriter.replaceOp(op, replaceWith); 469 else 470 rewriter.eraseOp(op); 471 472 return success(); 473 } 474 475 virtual Value getReplacementValue(AwaitType op, Value operand, 476 ConversionPatternRewriter &rewriter) const { 477 return Value(); 478 } 479 480 private: 481 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 482 }; 483 484 /// Lowering for `async.await` with a token operand. 485 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 486 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 487 488 public: 489 using Base::Base; 490 }; 491 492 /// Lowering for `async.await` with a value operand. 493 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 494 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 495 496 public: 497 using Base::Base; 498 499 Value 500 getReplacementValue(AwaitOp op, Value operand, 501 ConversionPatternRewriter &rewriter) const override { 502 // Load from the async value storage. 503 auto valueType = operand.getType().cast<ValueType>().getValueType(); 504 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 505 } 506 }; 507 508 /// Lowering for `async.await_all` operation. 509 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 510 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 511 512 public: 513 using Base::Base; 514 }; 515 516 } // namespace 517 518 //===----------------------------------------------------------------------===// 519 // Convert async.yield operation to async.runtime operations. 520 //===----------------------------------------------------------------------===// 521 522 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 523 public: 524 YieldOpLowering( 525 MLIRContext *ctx, 526 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 527 : OpConversionPattern<async::YieldOp>(ctx), 528 outlinedFunctions(outlinedFunctions) {} 529 530 LogicalResult 531 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 532 ConversionPatternRewriter &rewriter) const override { 533 // Check if yield operation is inside the async coroutine function. 534 auto func = op->template getParentOfType<FuncOp>(); 535 auto outlined = outlinedFunctions.find(func); 536 if (outlined == outlinedFunctions.end()) 537 return rewriter.notifyMatchFailure( 538 op, "operation is not inside the async coroutine function"); 539 540 Location loc = op->getLoc(); 541 const CoroMachinery &coro = outlined->getSecond(); 542 543 // Store yielded values into the async values storage and switch async 544 // values state to available. 545 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { 546 Value yieldValue = std::get<0>(tuple); 547 Value asyncValue = std::get<1>(tuple); 548 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 549 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 550 } 551 552 // Switch the coroutine completion token to available state. 553 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 554 555 return success(); 556 } 557 558 private: 559 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 560 }; 561 562 //===----------------------------------------------------------------------===// 563 // Convert std.assert operation to cond_br into `set_error` block. 564 //===----------------------------------------------------------------------===// 565 566 class AssertOpLowering : public OpConversionPattern<AssertOp> { 567 public: 568 AssertOpLowering(MLIRContext *ctx, 569 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 570 : OpConversionPattern<AssertOp>(ctx), 571 outlinedFunctions(outlinedFunctions) {} 572 573 LogicalResult 574 matchAndRewrite(AssertOp op, OpAdaptor adaptor, 575 ConversionPatternRewriter &rewriter) const override { 576 // Check if assert operation is inside the async coroutine function. 577 auto func = op->template getParentOfType<FuncOp>(); 578 auto outlined = outlinedFunctions.find(func); 579 if (outlined == outlinedFunctions.end()) 580 return rewriter.notifyMatchFailure( 581 op, "operation is not inside the async coroutine function"); 582 583 Location loc = op->getLoc(); 584 CoroMachinery &coro = outlined->getSecond(); 585 586 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 587 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 588 rewriter.create<CondBranchOp>(loc, adaptor.getArg(), 589 /*trueDest=*/cont, 590 /*trueArgs=*/ArrayRef<Value>(), 591 /*falseDest=*/setupSetErrorBlock(coro), 592 /*falseArgs=*/ArrayRef<Value>()); 593 rewriter.eraseOp(op); 594 595 return success(); 596 } 597 598 private: 599 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 600 }; 601 602 //===----------------------------------------------------------------------===// 603 604 /// Rewrite a func as a coroutine by: 605 /// 1) Wrapping the results into `async.value`. 606 /// 2) Prepending the results with `async.token`. 607 /// 3) Setting up coroutine blocks. 608 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 609 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) { 610 auto *ctx = func->getContext(); 611 auto loc = func.getLoc(); 612 SmallVector<Type> resultTypes; 613 resultTypes.reserve(func.getCallableResults().size()); 614 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 615 [](Type type) { return ValueType::get(type); }); 616 func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes)); 617 func.insertResult(0, TokenType::get(ctx), {}); 618 for (Block &block : func.getBlocks()) { 619 Operation *terminator = block.getTerminator(); 620 if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) { 621 ImplicitLocOpBuilder builder(loc, returnOp); 622 builder.create<YieldOp>(returnOp.getOperands()); 623 returnOp.erase(); 624 } 625 } 626 return setupCoroMachinery(func); 627 } 628 629 /// Rewrites a call into a function that has been rewritten as a coroutine. 630 /// 631 /// The invocation of this function is safe only when call ops are traversed in 632 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 633 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) { 634 auto loc = func.getLoc(); 635 ImplicitLocOpBuilder callBuilder(loc, oldCall); 636 auto newCall = callBuilder.create<CallOp>( 637 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 638 639 // Await on the async token and all the value results and unwrap the latter. 640 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 641 SmallVector<Value> unwrappedResults; 642 unwrappedResults.reserve(newCall->getResults().size() - 1); 643 for (Value result : newCall.getResults().drop_front()) 644 unwrappedResults.push_back( 645 callBuilder.create<AwaitOp>(loc, result).result()); 646 // Careful, when result of a call is piped into another call this could lead 647 // to a dangling pointer. 648 oldCall.replaceAllUsesWith(unwrappedResults); 649 oldCall.erase(); 650 } 651 652 static bool isAllowedToBlock(FuncOp func) { 653 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 654 } 655 656 static LogicalResult 657 funcsToCoroutines(ModuleOp module, 658 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) { 659 // The following code supports the general case when 2 functions mutually 660 // recurse into each other. Because of this and that we are relying on 661 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 662 // a FuncOp while inserting an equivalent coroutine, because that could lead 663 // to dangling pointers. 664 665 SmallVector<FuncOp> funcWorklist; 666 667 // Careful, it's okay to add a func to the worklist multiple times if and only 668 // if the loop processing the worklist will skip the functions that have 669 // already been converted to coroutines. 670 auto addToWorklist = [&](FuncOp func) { 671 if (isAllowedToBlock(func)) 672 return; 673 // N.B. To refactor this code into a separate pass the lookup in 674 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 675 // func and recognizing if it has a coroutine structure is messy. Passing 676 // this dict between the passes is ugly. 677 if (isAllowedToBlock(func) || 678 outlinedFunctions.find(func) == outlinedFunctions.end()) { 679 for (Operation &op : func.body().getOps()) { 680 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 681 funcWorklist.push_back(func); 682 break; 683 } 684 } 685 } 686 }; 687 688 // Traverse in post-order collecting for each func op the await ops it has. 689 for (FuncOp func : module.getOps<FuncOp>()) 690 addToWorklist(func); 691 692 SymbolTableCollection symbolTable; 693 SymbolUserMap symbolUserMap(symbolTable, module); 694 695 // Rewrite funcs, while updating call sites and adding them to the worklist. 696 while (!funcWorklist.empty()) { 697 auto func = funcWorklist.pop_back_val(); 698 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 699 if (!insertion.second) 700 // This function has already been processed because this is either 701 // the corecursive case, or a caller with multiple calls to a newly 702 // created corouting. Either way, skip updating the call sites. 703 continue; 704 insertion.first->second = rewriteFuncAsCoroutine(func); 705 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 706 symbolUserMap.getUsers(func).end()); 707 // If there are multiple calls from the same block they need to be traversed 708 // in reverse order so that symbolUserMap references are not invalidated 709 // when updating the users of the call op which is earlier in the block. 710 llvm::sort(users, [](Operation *a, Operation *b) { 711 Block *blockA = a->getBlock(); 712 Block *blockB = b->getBlock(); 713 // Impose arbitrary order on blocks so that there is a well-defined order. 714 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 715 }); 716 // Rewrite the callsites to await on results of the newly created coroutine. 717 for (Operation *op : users) { 718 if (CallOp call = dyn_cast<mlir::CallOp>(*op)) { 719 FuncOp caller = call->getParentOfType<FuncOp>(); 720 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 721 addToWorklist(caller); 722 } else { 723 op->emitError("Unexpected reference to func referenced by symbol"); 724 return failure(); 725 } 726 } 727 } 728 return success(); 729 } 730 731 //===----------------------------------------------------------------------===// 732 void AsyncToAsyncRuntimePass::runOnOperation() { 733 ModuleOp module = getOperation(); 734 SymbolTable symbolTable(module); 735 736 // Outline all `async.execute` body regions into async functions (coroutines). 737 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 738 739 module.walk([&](ExecuteOp execute) { 740 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 741 }); 742 743 LLVM_DEBUG({ 744 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 745 << " functions built from async.execute operations\n"; 746 }); 747 748 // Returns true if operation is inside the coroutine. 749 auto isInCoroutine = [&](Operation *op) -> bool { 750 auto parentFunc = op->getParentOfType<FuncOp>(); 751 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 752 }; 753 754 if (eliminateBlockingAwaitOps && 755 failed(funcsToCoroutines(module, outlinedFunctions))) { 756 signalPassFailure(); 757 return; 758 } 759 760 // Lower async operations to async.runtime operations. 761 MLIRContext *ctx = module->getContext(); 762 RewritePatternSet asyncPatterns(ctx); 763 764 // Conversion to async runtime augments original CFG with the coroutine CFG, 765 // and we have to make sure that structured control flow operations with async 766 // operations in nested regions will be converted to branch-based control flow 767 // before we add the coroutine basic blocks. 768 populateLoopToStdConversionPatterns(asyncPatterns); 769 770 // Async lowering does not use type converter because it must preserve all 771 // types for async.runtime operations. 772 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 773 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 774 AwaitAllOpLowering, YieldOpLowering>(ctx, 775 outlinedFunctions); 776 777 // Lower assertions to conditional branches into error blocks. 778 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 779 780 // All high level async operations must be lowered to the runtime operations. 781 ConversionTarget runtimeTarget(*ctx); 782 runtimeTarget.addLegalDialect<AsyncDialect>(); 783 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 784 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 785 786 // Decide if structured control flow has to be lowered to branch-based CFG. 787 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 788 auto walkResult = op->walk([&](Operation *nested) { 789 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 790 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 791 : WalkResult::advance(); 792 }); 793 return !walkResult.wasInterrupted(); 794 }); 795 runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp, 796 ConstantOp, BranchOp, CondBranchOp>(); 797 798 // Assertions must be converted to runtime errors inside async functions. 799 runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 800 auto func = op->getParentOfType<FuncOp>(); 801 return outlinedFunctions.find(func) == outlinedFunctions.end(); 802 }); 803 804 if (eliminateBlockingAwaitOps) 805 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 806 [&](RuntimeAwaitOp op) -> bool { 807 return isAllowedToBlock(op->getParentOfType<FuncOp>()); 808 }); 809 810 if (failed(applyPartialConversion(module, runtimeTarget, 811 std::move(asyncPatterns)))) { 812 signalPassFailure(); 813 return; 814 } 815 } 816 817 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 818 return std::make_unique<AsyncToAsyncRuntimePass>(); 819 } 820