1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Async/IR/Async.h" 18 #include "mlir/Dialect/Async/Passes.h" 19 #include "mlir/Dialect/SCF/SCF.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 #include "mlir/IR/BlockAndValueMapping.h" 22 #include "mlir/IR/ImplicitLocOpBuilder.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "mlir/Transforms/RegionUtils.h" 26 #include "llvm/ADT/SetVector.h" 27 #include "llvm/Support/Debug.h" 28 29 using namespace mlir; 30 using namespace mlir::async; 31 32 #define DEBUG_TYPE "async-to-async-runtime" 33 // Prefix for functions outlined from `async.execute` op regions. 34 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 35 36 namespace { 37 38 class AsyncToAsyncRuntimePass 39 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 40 public: 41 AsyncToAsyncRuntimePass() = default; 42 void runOnOperation() override; 43 }; 44 45 } // namespace 46 47 //===----------------------------------------------------------------------===// 48 // async.execute op outlining to the coroutine functions. 49 //===----------------------------------------------------------------------===// 50 51 /// Function targeted for coroutine transformation has two additional blocks at 52 /// the end: coroutine cleanup and coroutine suspension. 53 /// 54 /// async.await op lowering additionaly creates a resume block for each 55 /// operation to enable non-blocking waiting via coroutine suspension. 56 namespace { 57 struct CoroMachinery { 58 FuncOp func; 59 60 // Async execute region returns a completion token, and an async value for 61 // each yielded value. 62 // 63 // %token, %result = async.execute -> !async.value<T> { 64 // %0 = arith.constant ... : T 65 // async.yield %0 : T 66 // } 67 Value asyncToken; // token representing completion of the async region 68 llvm::SmallVector<Value, 4> returnValues; // returned async values 69 70 Value coroHandle; // coroutine handle (!async.coro.handle value) 71 Block *entry; // coroutine entry block 72 Block *setError; // switch completion token and all values to error state 73 Block *cleanup; // coroutine cleanup block 74 Block *suspend; // coroutine suspension block 75 }; 76 } // namespace 77 78 /// Utility to partially update the regular function CFG to the coroutine CFG 79 /// compatible with LLVM coroutines switched-resume lowering using 80 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 81 /// that branches into preexisting entry block. Also inserts trailing blocks. 82 /// 83 /// The result types of the passed `func` must start with an `async.token` 84 /// and be continued with some number of `async.value`s. 85 /// 86 /// The func given to this function needs to have been preprocessed to have 87 /// either branch or yield ops as terminators. Branches to the cleanup block are 88 /// inserted after each yield. 89 /// 90 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 91 /// 92 /// - `entry` block sets up the coroutine. 93 /// - `set_error` block sets completion token and async values state to error. 94 /// - `cleanup` block cleans up the coroutine state. 95 /// - `suspend block after the @llvm.coro.end() defines what value will be 96 /// returned to the initial caller of a coroutine. Everything before the 97 /// @llvm.coro.end() will be executed at every suspension point. 98 /// 99 /// Coroutine structure (only the important bits): 100 /// 101 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 102 /// { 103 /// ^entry(<function-arguments>): 104 /// %token = <async token> : !async.token // create async runtime token 105 /// %value = <async value> : !async.value<T> // create async value 106 /// %id = async.coro.id // create a coroutine id 107 /// %hdl = async.coro.begin %id // create a coroutine handle 108 /// br ^preexisting_entry_block 109 /// 110 /// /* preexisting blocks modified to branch to the cleanup block */ 111 /// 112 /// ^set_error: // this block created lazily only if needed (see code below) 113 /// async.runtime.set_error %token : !async.token 114 /// async.runtime.set_error %value : !async.value<T> 115 /// br ^cleanup 116 /// 117 /// ^cleanup: 118 /// async.coro.free %hdl // delete the coroutine state 119 /// br ^suspend 120 /// 121 /// ^suspend: 122 /// async.coro.end %hdl // marks the end of a coroutine 123 /// return %token, %value : !async.token, !async.value<T> 124 /// } 125 /// 126 static CoroMachinery setupCoroMachinery(FuncOp func) { 127 assert(!func.getBlocks().empty() && "Function must have an entry block"); 128 129 MLIRContext *ctx = func.getContext(); 130 Block *entryBlock = &func.getBlocks().front(); 131 Block *originalEntryBlock = 132 entryBlock->splitBlock(entryBlock->getOperations().begin()); 133 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 134 135 // ------------------------------------------------------------------------ // 136 // Allocate async token/values that we will return from a ramp function. 137 // ------------------------------------------------------------------------ // 138 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 139 140 llvm::SmallVector<Value, 4> retValues; 141 for (auto resType : func.getCallableResults().drop_front()) 142 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 143 144 // ------------------------------------------------------------------------ // 145 // Initialize coroutine: get coroutine id and coroutine handle. 146 // ------------------------------------------------------------------------ // 147 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 148 auto coroHdlOp = 149 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 150 builder.create<BranchOp>(originalEntryBlock); 151 152 Block *cleanupBlock = func.addBlock(); 153 Block *suspendBlock = func.addBlock(); 154 155 // ------------------------------------------------------------------------ // 156 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 157 // ------------------------------------------------------------------------ // 158 builder.setInsertionPointToStart(cleanupBlock); 159 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 160 161 // Branch into the suspend block. 162 builder.create<BranchOp>(suspendBlock); 163 164 // ------------------------------------------------------------------------ // 165 // Coroutine suspend block: mark the end of a coroutine and return allocated 166 // async token. 167 // ------------------------------------------------------------------------ // 168 builder.setInsertionPointToStart(suspendBlock); 169 170 // Mark the end of a coroutine: async.coro.end 171 builder.create<CoroEndOp>(coroHdlOp.handle()); 172 173 // Return created `async.token` and `async.values` from the suspend block. 174 // This will be the return value of a coroutine ramp function. 175 SmallVector<Value, 4> ret{retToken}; 176 ret.insert(ret.end(), retValues.begin(), retValues.end()); 177 builder.create<ReturnOp>(ret); 178 179 // `async.await` op lowering will create resume blocks for async 180 // continuations, and will conditionally branch to cleanup or suspend blocks. 181 182 for (Block &block : func.body().getBlocks()) { 183 if (&block == entryBlock || &block == cleanupBlock || 184 &block == suspendBlock) 185 continue; 186 Operation *terminator = block.getTerminator(); 187 if (auto yield = dyn_cast<YieldOp>(terminator)) { 188 builder.setInsertionPointToEnd(&block); 189 builder.create<BranchOp>(cleanupBlock); 190 } 191 } 192 193 CoroMachinery machinery; 194 machinery.func = func; 195 machinery.asyncToken = retToken; 196 machinery.returnValues = retValues; 197 machinery.coroHandle = coroHdlOp.handle(); 198 machinery.entry = entryBlock; 199 machinery.setError = nullptr; // created lazily only if needed 200 machinery.cleanup = cleanupBlock; 201 machinery.suspend = suspendBlock; 202 return machinery; 203 } 204 205 // Lazily creates `set_error` block only if it is required for lowering to the 206 // runtime operations (see for example lowering of assert operation). 207 static Block *setupSetErrorBlock(CoroMachinery &coro) { 208 if (coro.setError) 209 return coro.setError; 210 211 coro.setError = coro.func.addBlock(); 212 coro.setError->moveBefore(coro.cleanup); 213 214 auto builder = 215 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 216 217 // Coroutine set_error block: set error on token and all returned values. 218 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 219 for (Value retValue : coro.returnValues) 220 builder.create<RuntimeSetErrorOp>(retValue); 221 222 // Branch into the cleanup block. 223 builder.create<BranchOp>(coro.cleanup); 224 225 return coro.setError; 226 } 227 228 /// Outline the body region attached to the `async.execute` op into a standalone 229 /// function. 230 /// 231 /// Note that this is not reversible transformation. 232 static std::pair<FuncOp, CoroMachinery> 233 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 234 ModuleOp module = execute->getParentOfType<ModuleOp>(); 235 236 MLIRContext *ctx = module.getContext(); 237 Location loc = execute.getLoc(); 238 239 // Make sure that all constants will be inside the outlined async function to 240 // reduce the number of function arguments. 241 cloneConstantsIntoTheRegion(execute.body()); 242 243 // Collect all outlined function inputs. 244 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 245 execute.dependencies().end()); 246 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 247 getUsedValuesDefinedAbove(execute.body(), functionInputs); 248 249 // Collect types for the outlined function inputs and outputs. 250 auto typesRange = llvm::map_range( 251 functionInputs, [](Value value) { return value.getType(); }); 252 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 253 auto outputTypes = execute.getResultTypes(); 254 255 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 256 auto funcAttrs = ArrayRef<NamedAttribute>(); 257 258 // TODO: Derive outlined function name from the parent FuncOp (support 259 // multiple nested async.execute operations). 260 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 261 symbolTable.insert(func); 262 263 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 264 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 265 266 // Prepare for coroutine conversion by creating the body of the function. 267 { 268 size_t numDependencies = execute.dependencies().size(); 269 size_t numOperands = execute.operands().size(); 270 271 // Await on all dependencies before starting to execute the body region. 272 for (size_t i = 0; i < numDependencies; ++i) 273 builder.create<AwaitOp>(func.getArgument(i)); 274 275 // Await on all async value operands and unwrap the payload. 276 SmallVector<Value, 4> unwrappedOperands(numOperands); 277 for (size_t i = 0; i < numOperands; ++i) { 278 Value operand = func.getArgument(numDependencies + i); 279 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 280 } 281 282 // Map from function inputs defined above the execute op to the function 283 // arguments. 284 BlockAndValueMapping valueMapping; 285 valueMapping.map(functionInputs, func.getArguments()); 286 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 287 288 // Clone all operations from the execute operation body into the outlined 289 // function body. 290 for (Operation &op : execute.body().getOps()) 291 builder.clone(op, valueMapping); 292 } 293 294 // Adding entry/cleanup/suspend blocks. 295 CoroMachinery coro = setupCoroMachinery(func); 296 297 // Suspend async function at the end of an entry block, and resume it using 298 // Async resume operation (execution will be resumed in a thread managed by 299 // the async runtime). 300 { 301 BranchOp branch = cast<BranchOp>(coro.entry->getTerminator()); 302 builder.setInsertionPointToEnd(coro.entry); 303 304 // Save the coroutine state: async.coro.save 305 auto coroSaveOp = 306 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 307 308 // Pass coroutine to the runtime to be resumed on a runtime managed 309 // thread. 310 builder.create<RuntimeResumeOp>(coro.coroHandle); 311 312 // Add async.coro.suspend as a suspended block terminator. 313 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 314 branch.getDest(), coro.cleanup); 315 316 branch.erase(); 317 } 318 319 // Replace the original `async.execute` with a call to outlined function. 320 { 321 ImplicitLocOpBuilder callBuilder(loc, execute); 322 auto callOutlinedFunc = callBuilder.create<CallOp>( 323 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 324 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 325 execute.erase(); 326 } 327 328 return {func, coro}; 329 } 330 331 //===----------------------------------------------------------------------===// 332 // Convert async.create_group operation to async.runtime.create_group 333 //===----------------------------------------------------------------------===// 334 335 namespace { 336 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 337 public: 338 using OpConversionPattern::OpConversionPattern; 339 340 LogicalResult 341 matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, 342 ConversionPatternRewriter &rewriter) const override { 343 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 344 op, GroupType::get(op->getContext()), adaptor.getOperands()); 345 return success(); 346 } 347 }; 348 } // namespace 349 350 //===----------------------------------------------------------------------===// 351 // Convert async.add_to_group operation to async.runtime.add_to_group. 352 //===----------------------------------------------------------------------===// 353 354 namespace { 355 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 356 public: 357 using OpConversionPattern::OpConversionPattern; 358 359 LogicalResult 360 matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, 361 ConversionPatternRewriter &rewriter) const override { 362 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 363 op, rewriter.getIndexType(), adaptor.getOperands()); 364 return success(); 365 } 366 }; 367 } // namespace 368 369 //===----------------------------------------------------------------------===// 370 // Convert async.await and async.await_all operations to the async.runtime.await 371 // or async.runtime.await_and_resume operations. 372 //===----------------------------------------------------------------------===// 373 374 namespace { 375 template <typename AwaitType, typename AwaitableType> 376 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 377 using AwaitAdaptor = typename AwaitType::Adaptor; 378 379 public: 380 AwaitOpLoweringBase(MLIRContext *ctx, 381 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 382 : OpConversionPattern<AwaitType>(ctx), 383 outlinedFunctions(outlinedFunctions) {} 384 385 LogicalResult 386 matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, 387 ConversionPatternRewriter &rewriter) const override { 388 // We can only await on one the `AwaitableType` (for `await` it can be 389 // a `token` or a `value`, for `await_all` it must be a `group`). 390 if (!op.operand().getType().template isa<AwaitableType>()) 391 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 392 393 // Check if await operation is inside the outlined coroutine function. 394 auto func = op->template getParentOfType<FuncOp>(); 395 auto outlined = outlinedFunctions.find(func); 396 const bool isInCoroutine = outlined != outlinedFunctions.end(); 397 398 Location loc = op->getLoc(); 399 Value operand = adaptor.operand(); 400 401 Type i1 = rewriter.getI1Type(); 402 403 // Inside regular functions we use the blocking wait operation to wait for 404 // the async object (token, value or group) to become available. 405 if (!isInCoroutine) { 406 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 407 builder.create<RuntimeAwaitOp>(loc, operand); 408 409 // Assert that the awaited operands is not in the error state. 410 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand); 411 Value notError = builder.create<arith::XOrIOp>( 412 isError, builder.create<arith::ConstantOp>( 413 loc, i1, builder.getIntegerAttr(i1, 1))); 414 415 builder.create<AssertOp>(notError, 416 "Awaited async operand is in error state"); 417 } 418 419 // Inside the coroutine we convert await operation into coroutine suspension 420 // point, and resume execution asynchronously. 421 if (isInCoroutine) { 422 CoroMachinery &coro = outlined->getSecond(); 423 Block *suspended = op->getBlock(); 424 425 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 426 MLIRContext *ctx = op->getContext(); 427 428 // Save the coroutine state and resume on a runtime managed thread when 429 // the operand becomes available. 430 auto coroSaveOp = 431 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 432 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 433 434 // Split the entry block before the await operation. 435 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 436 437 // Add async.coro.suspend as a suspended block terminator. 438 builder.setInsertionPointToEnd(suspended); 439 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 440 coro.cleanup); 441 442 // Split the resume block into error checking and continuation. 443 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 444 445 // Check if the awaited value is in the error state. 446 builder.setInsertionPointToStart(resume); 447 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); 448 builder.create<CondBranchOp>(isError, 449 /*trueDest=*/setupSetErrorBlock(coro), 450 /*trueArgs=*/ArrayRef<Value>(), 451 /*falseDest=*/continuation, 452 /*falseArgs=*/ArrayRef<Value>()); 453 454 // Make sure that replacement value will be constructed in the 455 // continuation block. 456 rewriter.setInsertionPointToStart(continuation); 457 } 458 459 // Erase or replace the await operation with the new value. 460 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 461 rewriter.replaceOp(op, replaceWith); 462 else 463 rewriter.eraseOp(op); 464 465 return success(); 466 } 467 468 virtual Value getReplacementValue(AwaitType op, Value operand, 469 ConversionPatternRewriter &rewriter) const { 470 return Value(); 471 } 472 473 private: 474 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 475 }; 476 477 /// Lowering for `async.await` with a token operand. 478 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 479 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 480 481 public: 482 using Base::Base; 483 }; 484 485 /// Lowering for `async.await` with a value operand. 486 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 487 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 488 489 public: 490 using Base::Base; 491 492 Value 493 getReplacementValue(AwaitOp op, Value operand, 494 ConversionPatternRewriter &rewriter) const override { 495 // Load from the async value storage. 496 auto valueType = operand.getType().cast<ValueType>().getValueType(); 497 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 498 } 499 }; 500 501 /// Lowering for `async.await_all` operation. 502 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 503 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 504 505 public: 506 using Base::Base; 507 }; 508 509 } // namespace 510 511 //===----------------------------------------------------------------------===// 512 // Convert async.yield operation to async.runtime operations. 513 //===----------------------------------------------------------------------===// 514 515 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 516 public: 517 YieldOpLowering( 518 MLIRContext *ctx, 519 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 520 : OpConversionPattern<async::YieldOp>(ctx), 521 outlinedFunctions(outlinedFunctions) {} 522 523 LogicalResult 524 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 525 ConversionPatternRewriter &rewriter) const override { 526 // Check if yield operation is inside the async coroutine function. 527 auto func = op->template getParentOfType<FuncOp>(); 528 auto outlined = outlinedFunctions.find(func); 529 if (outlined == outlinedFunctions.end()) 530 return rewriter.notifyMatchFailure( 531 op, "operation is not inside the async coroutine function"); 532 533 Location loc = op->getLoc(); 534 const CoroMachinery &coro = outlined->getSecond(); 535 536 // Store yielded values into the async values storage and switch async 537 // values state to available. 538 for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { 539 Value yieldValue = std::get<0>(tuple); 540 Value asyncValue = std::get<1>(tuple); 541 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 542 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 543 } 544 545 // Switch the coroutine completion token to available state. 546 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 547 548 return success(); 549 } 550 551 private: 552 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 553 }; 554 555 //===----------------------------------------------------------------------===// 556 // Convert std.assert operation to cond_br into `set_error` block. 557 //===----------------------------------------------------------------------===// 558 559 class AssertOpLowering : public OpConversionPattern<AssertOp> { 560 public: 561 AssertOpLowering(MLIRContext *ctx, 562 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 563 : OpConversionPattern<AssertOp>(ctx), 564 outlinedFunctions(outlinedFunctions) {} 565 566 LogicalResult 567 matchAndRewrite(AssertOp op, OpAdaptor adaptor, 568 ConversionPatternRewriter &rewriter) const override { 569 // Check if assert operation is inside the async coroutine function. 570 auto func = op->template getParentOfType<FuncOp>(); 571 auto outlined = outlinedFunctions.find(func); 572 if (outlined == outlinedFunctions.end()) 573 return rewriter.notifyMatchFailure( 574 op, "operation is not inside the async coroutine function"); 575 576 Location loc = op->getLoc(); 577 CoroMachinery &coro = outlined->getSecond(); 578 579 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 580 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 581 rewriter.create<CondBranchOp>(loc, adaptor.getArg(), 582 /*trueDest=*/cont, 583 /*trueArgs=*/ArrayRef<Value>(), 584 /*falseDest=*/setupSetErrorBlock(coro), 585 /*falseArgs=*/ArrayRef<Value>()); 586 rewriter.eraseOp(op); 587 588 return success(); 589 } 590 591 private: 592 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 593 }; 594 595 //===----------------------------------------------------------------------===// 596 597 /// Rewrite a func as a coroutine by: 598 /// 1) Wrapping the results into `async.value`. 599 /// 2) Prepending the results with `async.token`. 600 /// 3) Setting up coroutine blocks. 601 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 602 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) { 603 auto *ctx = func->getContext(); 604 auto loc = func.getLoc(); 605 SmallVector<Type> resultTypes; 606 resultTypes.reserve(func.getCallableResults().size()); 607 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 608 [](Type type) { return ValueType::get(type); }); 609 func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes)); 610 func.insertResult(0, TokenType::get(ctx), {}); 611 for (Block &block : func.getBlocks()) { 612 Operation *terminator = block.getTerminator(); 613 if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) { 614 ImplicitLocOpBuilder builder(loc, returnOp); 615 builder.create<YieldOp>(returnOp.getOperands()); 616 returnOp.erase(); 617 } 618 } 619 return setupCoroMachinery(func); 620 } 621 622 /// Rewrites a call into a function that has been rewritten as a coroutine. 623 /// 624 /// The invocation of this function is safe only when call ops are traversed in 625 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 626 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) { 627 auto loc = func.getLoc(); 628 ImplicitLocOpBuilder callBuilder(loc, oldCall); 629 auto newCall = callBuilder.create<CallOp>( 630 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 631 632 // Await on the async token and all the value results and unwrap the latter. 633 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 634 SmallVector<Value> unwrappedResults; 635 unwrappedResults.reserve(newCall->getResults().size() - 1); 636 for (Value result : newCall.getResults().drop_front()) 637 unwrappedResults.push_back( 638 callBuilder.create<AwaitOp>(loc, result).result()); 639 // Careful, when result of a call is piped into another call this could lead 640 // to a dangling pointer. 641 oldCall.replaceAllUsesWith(unwrappedResults); 642 oldCall.erase(); 643 } 644 645 static bool isAllowedToBlock(FuncOp func) { 646 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 647 } 648 649 static LogicalResult 650 funcsToCoroutines(ModuleOp module, 651 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) { 652 // The following code supports the general case when 2 functions mutually 653 // recurse into each other. Because of this and that we are relying on 654 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 655 // a FuncOp while inserting an equivalent coroutine, because that could lead 656 // to dangling pointers. 657 658 SmallVector<FuncOp> funcWorklist; 659 660 // Careful, it's okay to add a func to the worklist multiple times if and only 661 // if the loop processing the worklist will skip the functions that have 662 // already been converted to coroutines. 663 auto addToWorklist = [&](FuncOp func) { 664 if (isAllowedToBlock(func)) 665 return; 666 // N.B. To refactor this code into a separate pass the lookup in 667 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 668 // func and recognizing if it has a coroutine structure is messy. Passing 669 // this dict between the passes is ugly. 670 if (isAllowedToBlock(func) || 671 outlinedFunctions.find(func) == outlinedFunctions.end()) { 672 for (Operation &op : func.body().getOps()) { 673 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 674 funcWorklist.push_back(func); 675 break; 676 } 677 } 678 } 679 }; 680 681 // Traverse in post-order collecting for each func op the await ops it has. 682 for (FuncOp func : module.getOps<FuncOp>()) 683 addToWorklist(func); 684 685 SymbolTableCollection symbolTable; 686 SymbolUserMap symbolUserMap(symbolTable, module); 687 688 // Rewrite funcs, while updating call sites and adding them to the worklist. 689 while (!funcWorklist.empty()) { 690 auto func = funcWorklist.pop_back_val(); 691 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 692 if (!insertion.second) 693 // This function has already been processed because this is either 694 // the corecursive case, or a caller with multiple calls to a newly 695 // created corouting. Either way, skip updating the call sites. 696 continue; 697 insertion.first->second = rewriteFuncAsCoroutine(func); 698 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 699 symbolUserMap.getUsers(func).end()); 700 // If there are multiple calls from the same block they need to be traversed 701 // in reverse order so that symbolUserMap references are not invalidated 702 // when updating the users of the call op which is earlier in the block. 703 llvm::sort(users, [](Operation *a, Operation *b) { 704 Block *blockA = a->getBlock(); 705 Block *blockB = b->getBlock(); 706 // Impose arbitrary order on blocks so that there is a well-defined order. 707 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 708 }); 709 // Rewrite the callsites to await on results of the newly created coroutine. 710 for (Operation *op : users) { 711 if (CallOp call = dyn_cast<mlir::CallOp>(*op)) { 712 FuncOp caller = call->getParentOfType<FuncOp>(); 713 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 714 addToWorklist(caller); 715 } else { 716 op->emitError("Unexpected reference to func referenced by symbol"); 717 return failure(); 718 } 719 } 720 } 721 return success(); 722 } 723 724 //===----------------------------------------------------------------------===// 725 void AsyncToAsyncRuntimePass::runOnOperation() { 726 ModuleOp module = getOperation(); 727 SymbolTable symbolTable(module); 728 729 // Outline all `async.execute` body regions into async functions (coroutines). 730 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 731 732 module.walk([&](ExecuteOp execute) { 733 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 734 }); 735 736 LLVM_DEBUG({ 737 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 738 << " functions built from async.execute operations\n"; 739 }); 740 741 // Returns true if operation is inside the coroutine. 742 auto isInCoroutine = [&](Operation *op) -> bool { 743 auto parentFunc = op->getParentOfType<FuncOp>(); 744 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 745 }; 746 747 if (eliminateBlockingAwaitOps && 748 failed(funcsToCoroutines(module, outlinedFunctions))) { 749 signalPassFailure(); 750 return; 751 } 752 753 // Lower async operations to async.runtime operations. 754 MLIRContext *ctx = module->getContext(); 755 RewritePatternSet asyncPatterns(ctx); 756 757 // Conversion to async runtime augments original CFG with the coroutine CFG, 758 // and we have to make sure that structured control flow operations with async 759 // operations in nested regions will be converted to branch-based control flow 760 // before we add the coroutine basic blocks. 761 populateLoopToStdConversionPatterns(asyncPatterns); 762 763 // Async lowering does not use type converter because it must preserve all 764 // types for async.runtime operations. 765 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 766 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 767 AwaitAllOpLowering, YieldOpLowering>(ctx, 768 outlinedFunctions); 769 770 // Lower assertions to conditional branches into error blocks. 771 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 772 773 // All high level async operations must be lowered to the runtime operations. 774 ConversionTarget runtimeTarget(*ctx); 775 runtimeTarget.addLegalDialect<AsyncDialect>(); 776 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 777 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 778 779 // Decide if structured control flow has to be lowered to branch-based CFG. 780 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 781 auto walkResult = op->walk([&](Operation *nested) { 782 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 783 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 784 : WalkResult::advance(); 785 }); 786 return !walkResult.wasInterrupted(); 787 }); 788 runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp, 789 ConstantOp, BranchOp, CondBranchOp>(); 790 791 // Assertions must be converted to runtime errors inside async functions. 792 runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 793 auto func = op->getParentOfType<FuncOp>(); 794 return outlinedFunctions.find(func) == outlinedFunctions.end(); 795 }); 796 797 if (eliminateBlockingAwaitOps) 798 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 799 [&](RuntimeAwaitOp op) -> bool { 800 return isAllowedToBlock(op->getParentOfType<FuncOp>()); 801 }); 802 803 if (failed(applyPartialConversion(module, runtimeTarget, 804 std::move(asyncPatterns)))) { 805 signalPassFailure(); 806 return; 807 } 808 } 809 810 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 811 return std::make_unique<AsyncToAsyncRuntimePass>(); 812 } 813