1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/Async/Passes.h" 18 #include "mlir/Dialect/SCF/SCF.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/IR/BlockAndValueMapping.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "mlir/Transforms/RegionUtils.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/Support/Debug.h" 27 28 using namespace mlir; 29 using namespace mlir::async; 30 31 #define DEBUG_TYPE "async-to-async-runtime" 32 // Prefix for functions outlined from `async.execute` op regions. 33 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 34 35 namespace { 36 37 class AsyncToAsyncRuntimePass 38 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 39 public: 40 AsyncToAsyncRuntimePass() = default; 41 void runOnOperation() override; 42 }; 43 44 } // namespace 45 46 //===----------------------------------------------------------------------===// 47 // async.execute op outlining to the coroutine functions. 48 //===----------------------------------------------------------------------===// 49 50 /// Function targeted for coroutine transformation has two additional blocks at 51 /// the end: coroutine cleanup and coroutine suspension. 52 /// 53 /// async.await op lowering additionaly creates a resume block for each 54 /// operation to enable non-blocking waiting via coroutine suspension. 55 namespace { 56 struct CoroMachinery { 57 FuncOp func; 58 59 // Async execute region returns a completion token, and an async value for 60 // each yielded value. 61 // 62 // %token, %result = async.execute -> !async.value<T> { 63 // %0 = constant ... : T 64 // async.yield %0 : T 65 // } 66 Value asyncToken; // token representing completion of the async region 67 llvm::SmallVector<Value, 4> returnValues; // returned async values 68 69 Value coroHandle; // coroutine handle (!async.coro.handle value) 70 Block *setError; // switch completion token and all values to error state 71 Block *cleanup; // coroutine cleanup block 72 Block *suspend; // coroutine suspension block 73 }; 74 } // namespace 75 76 /// Utility to partially update the regular function CFG to the coroutine CFG 77 /// compatible with LLVM coroutines switched-resume lowering using 78 /// `async.runtime.*` and `async.coro.*` operations. Modifies the entry block 79 /// by prepending its ops with coroutine setup. Also inserts trailing blocks. 80 /// 81 /// The result types of the passed `func` must start with an `async.token` 82 /// and be continued with some number of `async.value`s. 83 /// 84 /// It's up to the caller of this function to fix up the terminators of the 85 /// preexisting blocks of the passed func op. If the passed `func` is legal, 86 /// this typically means rewriting every return op as a yield op and a branch op 87 /// to the suspend block. 88 /// 89 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 90 /// 91 /// - `entry` block sets up the coroutine. 92 /// - `set_error` block sets completion token and async values state to error. 93 /// - `cleanup` block cleans up the coroutine state. 94 /// - `suspend block after the @llvm.coro.end() defines what value will be 95 /// returned to the initial caller of a coroutine. Everything before the 96 /// @llvm.coro.end() will be executed at every suspension point. 97 /// 98 /// Coroutine structure (only the important bits): 99 /// 100 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 101 /// { 102 /// ^entry(<function-arguments>): 103 /// %token = <async token> : !async.token // create async runtime token 104 /// %value = <async value> : !async.value<T> // create async value 105 /// %id = async.coro.id // create a coroutine id 106 /// %hdl = async.coro.begin %id // create a coroutine handle 107 /// /* other ops of the preexisting entry block */ 108 /// 109 /// /* other preexisting blocks */ 110 /// 111 /// ^set_error: // this block created lazily only if needed (see code below) 112 /// async.runtime.set_error %token : !async.token 113 /// async.runtime.set_error %value : !async.value<T> 114 /// br ^cleanup 115 /// 116 /// ^cleanup: 117 /// async.coro.free %hdl // delete the coroutine state 118 /// br ^suspend 119 /// 120 /// ^suspend: 121 /// async.coro.end %hdl // marks the end of a coroutine 122 /// return %token, %value : !async.token, !async.value<T> 123 /// } 124 /// 125 static CoroMachinery setupCoroMachinery(FuncOp func) { 126 assert(!func.getBlocks().empty() && "Function must have an entry block"); 127 128 MLIRContext *ctx = func.getContext(); 129 Block *entryBlock = &func.getBlocks().front(); 130 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 131 132 // ------------------------------------------------------------------------ // 133 // Allocate async token/values that we will return from a ramp function. 134 // ------------------------------------------------------------------------ // 135 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 136 137 llvm::SmallVector<Value, 4> retValues; 138 for (auto resType : func.getCallableResults().drop_front()) 139 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 140 141 // ------------------------------------------------------------------------ // 142 // Initialize coroutine: get coroutine id and coroutine handle. 143 // ------------------------------------------------------------------------ // 144 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 145 auto coroHdlOp = 146 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 147 148 Block *cleanupBlock = func.addBlock(); 149 Block *suspendBlock = func.addBlock(); 150 151 // ------------------------------------------------------------------------ // 152 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 153 // ------------------------------------------------------------------------ // 154 builder.setInsertionPointToStart(cleanupBlock); 155 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 156 157 // Branch into the suspend block. 158 builder.create<BranchOp>(suspendBlock); 159 160 // ------------------------------------------------------------------------ // 161 // Coroutine suspend block: mark the end of a coroutine and return allocated 162 // async token. 163 // ------------------------------------------------------------------------ // 164 builder.setInsertionPointToStart(suspendBlock); 165 166 // Mark the end of a coroutine: async.coro.end 167 builder.create<CoroEndOp>(coroHdlOp.handle()); 168 169 // Return created `async.token` and `async.values` from the suspend block. 170 // This will be the return value of a coroutine ramp function. 171 SmallVector<Value, 4> ret{retToken}; 172 ret.insert(ret.end(), retValues.begin(), retValues.end()); 173 builder.create<ReturnOp>(ret); 174 175 // `async.await` op lowering will create resume blocks for async 176 // continuations, and will conditionally branch to cleanup or suspend blocks. 177 178 CoroMachinery machinery; 179 machinery.func = func; 180 machinery.asyncToken = retToken; 181 machinery.returnValues = retValues; 182 machinery.coroHandle = coroHdlOp.handle(); 183 machinery.setError = nullptr; // created lazily only if needed 184 machinery.cleanup = cleanupBlock; 185 machinery.suspend = suspendBlock; 186 return machinery; 187 } 188 189 // Lazily creates `set_error` block only if it is required for lowering to the 190 // runtime operations (see for example lowering of assert operation). 191 static Block *setupSetErrorBlock(CoroMachinery &coro) { 192 if (coro.setError) 193 return coro.setError; 194 195 coro.setError = coro.func.addBlock(); 196 coro.setError->moveBefore(coro.cleanup); 197 198 auto builder = 199 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 200 201 // Coroutine set_error block: set error on token and all returned values. 202 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 203 for (Value retValue : coro.returnValues) 204 builder.create<RuntimeSetErrorOp>(retValue); 205 206 // Branch into the cleanup block. 207 builder.create<BranchOp>(coro.cleanup); 208 209 return coro.setError; 210 } 211 212 /// Outline the body region attached to the `async.execute` op into a standalone 213 /// function. 214 /// 215 /// Note that this is not reversible transformation. 216 static std::pair<FuncOp, CoroMachinery> 217 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 218 ModuleOp module = execute->getParentOfType<ModuleOp>(); 219 220 MLIRContext *ctx = module.getContext(); 221 Location loc = execute.getLoc(); 222 223 // Collect all outlined function inputs. 224 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 225 execute.dependencies().end()); 226 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 227 getUsedValuesDefinedAbove(execute.body(), functionInputs); 228 229 // Collect types for the outlined function inputs and outputs. 230 auto typesRange = llvm::map_range( 231 functionInputs, [](Value value) { return value.getType(); }); 232 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 233 auto outputTypes = execute.getResultTypes(); 234 235 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 236 auto funcAttrs = ArrayRef<NamedAttribute>(); 237 238 // TODO: Derive outlined function name from the parent FuncOp (support 239 // multiple nested async.execute operations). 240 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 241 symbolTable.insert(func); 242 243 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 244 245 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 246 // blocks, adding async.coro operations and setting up control flow. 247 func.addEntryBlock(); 248 CoroMachinery coro = setupCoroMachinery(func); 249 250 // Suspend async function at the end of an entry block, and resume it using 251 // Async resume operation (execution will be resumed in a thread managed by 252 // the async runtime). 253 Block *entryBlock = &func.getBlocks().front(); 254 auto builder = ImplicitLocOpBuilder::atBlockEnd(loc, entryBlock); 255 256 // Save the coroutine state: async.coro.save 257 auto coroSaveOp = 258 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 259 260 // Pass coroutine to the runtime to be resumed on a runtime managed thread. 261 builder.create<RuntimeResumeOp>(coro.coroHandle); 262 builder.create<BranchOp>(coro.cleanup); 263 264 // Split the entry block before the terminator (branch to suspend block). 265 auto *terminatorOp = entryBlock->getTerminator(); 266 Block *suspended = terminatorOp->getBlock(); 267 Block *resume = suspended->splitBlock(terminatorOp); 268 269 // Add async.coro.suspend as a suspended block terminator. 270 builder.setInsertionPointToEnd(suspended); 271 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 272 coro.cleanup); 273 274 size_t numDependencies = execute.dependencies().size(); 275 size_t numOperands = execute.operands().size(); 276 277 // Await on all dependencies before starting to execute the body region. 278 builder.setInsertionPointToStart(resume); 279 for (size_t i = 0; i < numDependencies; ++i) 280 builder.create<AwaitOp>(func.getArgument(i)); 281 282 // Await on all async value operands and unwrap the payload. 283 SmallVector<Value, 4> unwrappedOperands(numOperands); 284 for (size_t i = 0; i < numOperands; ++i) { 285 Value operand = func.getArgument(numDependencies + i); 286 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 287 } 288 289 // Map from function inputs defined above the execute op to the function 290 // arguments. 291 BlockAndValueMapping valueMapping; 292 valueMapping.map(functionInputs, func.getArguments()); 293 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 294 295 // Clone all operations from the execute operation body into the outlined 296 // function body. 297 for (Operation &op : execute.body().getOps()) 298 builder.clone(op, valueMapping); 299 300 // Replace the original `async.execute` with a call to outlined function. 301 ImplicitLocOpBuilder callBuilder(loc, execute); 302 auto callOutlinedFunc = callBuilder.create<CallOp>( 303 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 304 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 305 execute.erase(); 306 307 return {func, coro}; 308 } 309 310 //===----------------------------------------------------------------------===// 311 // Convert async.create_group operation to async.runtime.create_group 312 //===----------------------------------------------------------------------===// 313 314 namespace { 315 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 316 public: 317 using OpConversionPattern::OpConversionPattern; 318 319 LogicalResult 320 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 321 ConversionPatternRewriter &rewriter) const override { 322 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 323 op, GroupType::get(op->getContext()), operands); 324 return success(); 325 } 326 }; 327 } // namespace 328 329 //===----------------------------------------------------------------------===// 330 // Convert async.add_to_group operation to async.runtime.add_to_group. 331 //===----------------------------------------------------------------------===// 332 333 namespace { 334 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 335 public: 336 using OpConversionPattern::OpConversionPattern; 337 338 LogicalResult 339 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 340 ConversionPatternRewriter &rewriter) const override { 341 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 342 op, rewriter.getIndexType(), operands); 343 return success(); 344 } 345 }; 346 } // namespace 347 348 //===----------------------------------------------------------------------===// 349 // Convert async.await and async.await_all operations to the async.runtime.await 350 // or async.runtime.await_and_resume operations. 351 //===----------------------------------------------------------------------===// 352 353 namespace { 354 template <typename AwaitType, typename AwaitableType> 355 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 356 using AwaitAdaptor = typename AwaitType::Adaptor; 357 358 public: 359 AwaitOpLoweringBase(MLIRContext *ctx, 360 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 361 : OpConversionPattern<AwaitType>(ctx), 362 outlinedFunctions(outlinedFunctions) {} 363 364 LogicalResult 365 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 366 ConversionPatternRewriter &rewriter) const override { 367 // We can only await on one the `AwaitableType` (for `await` it can be 368 // a `token` or a `value`, for `await_all` it must be a `group`). 369 if (!op.operand().getType().template isa<AwaitableType>()) 370 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 371 372 // Check if await operation is inside the outlined coroutine function. 373 auto func = op->template getParentOfType<FuncOp>(); 374 auto outlined = outlinedFunctions.find(func); 375 const bool isInCoroutine = outlined != outlinedFunctions.end(); 376 377 Location loc = op->getLoc(); 378 Value operand = AwaitAdaptor(operands).operand(); 379 380 // Inside regular functions we use the blocking wait operation to wait for 381 // the async object (token, value or group) to become available. 382 if (!isInCoroutine) 383 rewriter.create<RuntimeAwaitOp>(loc, operand); 384 385 // Inside the coroutine we convert await operation into coroutine suspension 386 // point, and resume execution asynchronously. 387 if (isInCoroutine) { 388 CoroMachinery &coro = outlined->getSecond(); 389 Block *suspended = op->getBlock(); 390 391 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 392 MLIRContext *ctx = op->getContext(); 393 394 // Save the coroutine state and resume on a runtime managed thread when 395 // the operand becomes available. 396 auto coroSaveOp = 397 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 398 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 399 400 // Split the entry block before the await operation. 401 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 402 403 // Add async.coro.suspend as a suspended block terminator. 404 builder.setInsertionPointToEnd(suspended); 405 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 406 coro.cleanup); 407 408 // Split the resume block into error checking and continuation. 409 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 410 411 // Check if the awaited value is in the error state. 412 builder.setInsertionPointToStart(resume); 413 auto isError = 414 builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand); 415 builder.create<CondBranchOp>(isError, 416 /*trueDest=*/setupSetErrorBlock(coro), 417 /*trueArgs=*/ArrayRef<Value>(), 418 /*falseDest=*/continuation, 419 /*falseArgs=*/ArrayRef<Value>()); 420 421 // Make sure that replacement value will be constructed in the 422 // continuation block. 423 rewriter.setInsertionPointToStart(continuation); 424 } 425 426 // Erase or replace the await operation with the new value. 427 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 428 rewriter.replaceOp(op, replaceWith); 429 else 430 rewriter.eraseOp(op); 431 432 return success(); 433 } 434 435 virtual Value getReplacementValue(AwaitType op, Value operand, 436 ConversionPatternRewriter &rewriter) const { 437 return Value(); 438 } 439 440 private: 441 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 442 }; 443 444 /// Lowering for `async.await` with a token operand. 445 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 446 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 447 448 public: 449 using Base::Base; 450 }; 451 452 /// Lowering for `async.await` with a value operand. 453 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 454 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 455 456 public: 457 using Base::Base; 458 459 Value 460 getReplacementValue(AwaitOp op, Value operand, 461 ConversionPatternRewriter &rewriter) const override { 462 // Load from the async value storage. 463 auto valueType = operand.getType().cast<ValueType>().getValueType(); 464 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 465 } 466 }; 467 468 /// Lowering for `async.await_all` operation. 469 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 470 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 471 472 public: 473 using Base::Base; 474 }; 475 476 } // namespace 477 478 //===----------------------------------------------------------------------===// 479 // Convert async.yield operation to async.runtime operations. 480 //===----------------------------------------------------------------------===// 481 482 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 483 public: 484 YieldOpLowering( 485 MLIRContext *ctx, 486 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 487 : OpConversionPattern<async::YieldOp>(ctx), 488 outlinedFunctions(outlinedFunctions) {} 489 490 LogicalResult 491 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 492 ConversionPatternRewriter &rewriter) const override { 493 // Check if yield operation is inside the async coroutine function. 494 auto func = op->template getParentOfType<FuncOp>(); 495 auto outlined = outlinedFunctions.find(func); 496 if (outlined == outlinedFunctions.end()) 497 return rewriter.notifyMatchFailure( 498 op, "operation is not inside the async coroutine function"); 499 500 Location loc = op->getLoc(); 501 const CoroMachinery &coro = outlined->getSecond(); 502 503 // Store yielded values into the async values storage and switch async 504 // values state to available. 505 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 506 Value yieldValue = std::get<0>(tuple); 507 Value asyncValue = std::get<1>(tuple); 508 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 509 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 510 } 511 512 // Switch the coroutine completion token to available state. 513 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 514 515 return success(); 516 } 517 518 private: 519 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 520 }; 521 522 //===----------------------------------------------------------------------===// 523 // Convert std.assert operation to cond_br into `set_error` block. 524 //===----------------------------------------------------------------------===// 525 526 class AssertOpLowering : public OpConversionPattern<AssertOp> { 527 public: 528 AssertOpLowering(MLIRContext *ctx, 529 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 530 : OpConversionPattern<AssertOp>(ctx), 531 outlinedFunctions(outlinedFunctions) {} 532 533 LogicalResult 534 matchAndRewrite(AssertOp op, ArrayRef<Value> operands, 535 ConversionPatternRewriter &rewriter) const override { 536 // Check if assert operation is inside the async coroutine function. 537 auto func = op->template getParentOfType<FuncOp>(); 538 auto outlined = outlinedFunctions.find(func); 539 if (outlined == outlinedFunctions.end()) 540 return rewriter.notifyMatchFailure( 541 op, "operation is not inside the async coroutine function"); 542 543 Location loc = op->getLoc(); 544 CoroMachinery &coro = outlined->getSecond(); 545 546 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 547 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 548 rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(), 549 /*trueDest=*/cont, 550 /*trueArgs=*/ArrayRef<Value>(), 551 /*falseDest=*/setupSetErrorBlock(coro), 552 /*falseArgs=*/ArrayRef<Value>()); 553 rewriter.eraseOp(op); 554 555 return success(); 556 } 557 558 private: 559 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 560 }; 561 562 //===----------------------------------------------------------------------===// 563 564 /// Rewrite a func as a coroutine by: 565 /// 1) Wrapping the results into `async.value`. 566 /// 2) Prepending the results with `async.token`. 567 /// 3) Setting up coroutine blocks. 568 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 569 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) { 570 auto *ctx = func->getContext(); 571 auto loc = func.getLoc(); 572 SmallVector<Type> resultTypes; 573 resultTypes.reserve(func.getCallableResults().size()); 574 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 575 [](Type type) { return ValueType::get(type); }); 576 func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes)); 577 func.insertResult(0, TokenType::get(ctx), {}); 578 CoroMachinery coro = setupCoroMachinery(func); 579 for (Block &block : func.getBlocks()) { 580 if (&block == coro.suspend) 581 continue; 582 583 Operation *terminator = block.getTerminator(); 584 if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) { 585 ImplicitLocOpBuilder builder(loc, returnOp); 586 builder.create<YieldOp>(returnOp.getOperands()); 587 builder.create<BranchOp>(coro.cleanup); 588 returnOp.erase(); 589 } 590 } 591 return coro; 592 } 593 594 /// Rewrites a call into a function that has been rewritten as a coroutine. 595 /// 596 /// The invocation of this function is safe only when call ops are traversed in 597 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 598 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) { 599 auto loc = func.getLoc(); 600 ImplicitLocOpBuilder callBuilder(loc, oldCall); 601 auto newCall = callBuilder.create<CallOp>( 602 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 603 604 // Await on the async token and all the value results and unwrap the latter. 605 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 606 SmallVector<Value> unwrappedResults; 607 unwrappedResults.reserve(newCall->getResults().size() - 1); 608 for (Value result : newCall.getResults().drop_front()) 609 unwrappedResults.push_back( 610 callBuilder.create<AwaitOp>(loc, result).result()); 611 // Careful, when result of a call is piped into another call this could lead 612 // to a dangling pointer. 613 oldCall.replaceAllUsesWith(unwrappedResults); 614 oldCall.erase(); 615 } 616 617 static LogicalResult 618 funcsToCoroutines(ModuleOp module, 619 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) { 620 // The following code supports the general case when 2 functions mutually 621 // recurse into each other. Because of this and that we are relying on 622 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 623 // a FuncOp while inserting an equivalent coroutine, because that could lead 624 // to dangling pointers. 625 626 SmallVector<FuncOp> funcWorklist; 627 628 // Careful, it's okay to add a func to the worklist multiple times if and only 629 // if the loop processing the worklist will skip the functions that have 630 // already been converted to coroutines. 631 auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) { 632 // N.B. To refactor this code into a separate pass the lookup in 633 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 634 // func and recognizing if it has a coroutine structure is messy. Passing 635 // this dict between the passes is ugly. 636 if (outlinedFunctions.find(func) == outlinedFunctions.end()) { 637 for (Operation &op : func.body().getOps()) { 638 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 639 funcWorklist.push_back(func); 640 break; 641 } 642 } 643 } 644 }; 645 646 // Traverse in post-order collecting for each func op the await ops it has. 647 for (FuncOp func : module.getOps<FuncOp>()) 648 addToWorklist(func); 649 650 SymbolTableCollection symbolTable; 651 SymbolUserMap symbolUserMap(symbolTable, module); 652 653 // Rewrite funcs, while updating call sites and adding them to the worklist. 654 while (!funcWorklist.empty()) { 655 auto func = funcWorklist.pop_back_val(); 656 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 657 if (!insertion.second) 658 // This function has already been processed because this is either 659 // the corecursive case, or a caller with multiple calls to a newly 660 // created corouting. Either way, skip updating the call sites. 661 continue; 662 insertion.first->second = rewriteFuncAsCoroutine(func); 663 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 664 symbolUserMap.getUsers(func).end()); 665 // If there are multiple calls from the same block they need to be traversed 666 // in reverse order so that symbolUserMap references are not invalidated 667 // when updating the users of the call op which is earlier in the block. 668 llvm::sort(users, [](Operation *a, Operation *b) { 669 Block *blockA = a->getBlock(); 670 Block *blockB = b->getBlock(); 671 // Impose arbitrary order on blocks so that there is a well-defined order. 672 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 673 }); 674 // Rewrite the callsites to await on results of the newly created coroutine. 675 for (Operation *op : users) { 676 if (CallOp call = dyn_cast<mlir::CallOp>(*op)) { 677 FuncOp caller = call->getParentOfType<FuncOp>(); 678 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 679 addToWorklist(caller); 680 } else { 681 op->emitError("Unexpected reference to func referenced by symbol"); 682 return failure(); 683 } 684 } 685 } 686 return success(); 687 } 688 689 //===----------------------------------------------------------------------===// 690 void AsyncToAsyncRuntimePass::runOnOperation() { 691 ModuleOp module = getOperation(); 692 SymbolTable symbolTable(module); 693 694 // Outline all `async.execute` body regions into async functions (coroutines). 695 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 696 697 module.walk([&](ExecuteOp execute) { 698 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 699 }); 700 701 LLVM_DEBUG({ 702 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 703 << " functions built from async.execute operations\n"; 704 }); 705 706 // Returns true if operation is inside the coroutine. 707 auto isInCoroutine = [&](Operation *op) -> bool { 708 auto parentFunc = op->getParentOfType<FuncOp>(); 709 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 710 }; 711 712 if (eliminateBlockingAwaitOps && 713 failed(funcsToCoroutines(module, outlinedFunctions))) { 714 signalPassFailure(); 715 return; 716 } 717 718 // Lower async operations to async.runtime operations. 719 MLIRContext *ctx = module->getContext(); 720 RewritePatternSet asyncPatterns(ctx); 721 722 // Conversion to async runtime augments original CFG with the coroutine CFG, 723 // and we have to make sure that structured control flow operations with async 724 // operations in nested regions will be converted to branch-based control flow 725 // before we add the coroutine basic blocks. 726 populateLoopToStdConversionPatterns(asyncPatterns); 727 728 // Async lowering does not use type converter because it must preserve all 729 // types for async.runtime operations. 730 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 731 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 732 AwaitAllOpLowering, YieldOpLowering>(ctx, 733 outlinedFunctions); 734 735 // Lower assertions to conditional branches into error blocks. 736 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 737 738 // All high level async operations must be lowered to the runtime operations. 739 ConversionTarget runtimeTarget(*ctx); 740 runtimeTarget.addLegalDialect<AsyncDialect>(); 741 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 742 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 743 744 // Decide if structured control flow has to be lowered to branch-based CFG. 745 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 746 auto walkResult = op->walk([&](Operation *nested) { 747 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 748 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 749 : WalkResult::advance(); 750 }); 751 return !walkResult.wasInterrupted(); 752 }); 753 runtimeTarget.addLegalOp<BranchOp, CondBranchOp>(); 754 755 // Assertions must be converted to runtime errors inside async functions. 756 runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 757 auto func = op->getParentOfType<FuncOp>(); 758 return outlinedFunctions.find(func) == outlinedFunctions.end(); 759 }); 760 761 if (eliminateBlockingAwaitOps) 762 runtimeTarget.addIllegalOp<RuntimeAwaitOp>(); 763 764 if (failed(applyPartialConversion(module, runtimeTarget, 765 std::move(asyncPatterns)))) { 766 signalPassFailure(); 767 return; 768 } 769 } 770 771 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 772 return std::make_unique<AsyncToAsyncRuntimePass>(); 773 } 774