//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering from high level async operations to async.coro
// and async.runtime operations.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-to-async-runtime"
// Prefix for functions outlined from `async.execute` op regions.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

namespace {

/// Module pass that outlines `async.execute` regions into coroutine functions
/// and lowers the high level async operations to `async.runtime.*` and
/// `async.coro.*` operations. The work is done in `runOnOperation`, which is
/// defined at the end of this file.
class AsyncToAsyncRuntimePass
    : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
public:
  AsyncToAsyncRuntimePass() = default;
  void runOnOperation() override;
};

} // namespace

//===----------------------------------------------------------------------===//
// async.execute op outlining to the coroutine functions.
48 //===----------------------------------------------------------------------===// 49 50 /// Function targeted for coroutine transformation has two additional blocks at 51 /// the end: coroutine cleanup and coroutine suspension. 52 /// 53 /// async.await op lowering additionaly creates a resume block for each 54 /// operation to enable non-blocking waiting via coroutine suspension. 55 namespace { 56 struct CoroMachinery { 57 FuncOp func; 58 59 // Async execute region returns a completion token, and an async value for 60 // each yielded value. 61 // 62 // %token, %result = async.execute -> !async.value<T> { 63 // %0 = constant ... : T 64 // async.yield %0 : T 65 // } 66 Value asyncToken; // token representing completion of the async region 67 llvm::SmallVector<Value, 4> returnValues; // returned async values 68 69 Value coroHandle; // coroutine handle (!async.coro.handle value) 70 Block *entry; // coroutine entry block 71 Block *setError; // switch completion token and all values to error state 72 Block *cleanup; // coroutine cleanup block 73 Block *suspend; // coroutine suspension block 74 }; 75 } // namespace 76 77 /// Utility to partially update the regular function CFG to the coroutine CFG 78 /// compatible with LLVM coroutines switched-resume lowering using 79 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 80 /// that branches into preexisting entry block. Also inserts trailing blocks. 81 /// 82 /// The result types of the passed `func` must start with an `async.token` 83 /// and be continued with some number of `async.value`s. 84 /// 85 /// The func given to this function needs to have been preprocessed to have 86 /// either branch or yield ops as terminators. Branches to the cleanup block are 87 /// inserted after each yield. 88 /// 89 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 90 /// 91 /// - `entry` block sets up the coroutine. 
92 /// - `set_error` block sets completion token and async values state to error. 93 /// - `cleanup` block cleans up the coroutine state. 94 /// - `suspend block after the @llvm.coro.end() defines what value will be 95 /// returned to the initial caller of a coroutine. Everything before the 96 /// @llvm.coro.end() will be executed at every suspension point. 97 /// 98 /// Coroutine structure (only the important bits): 99 /// 100 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 101 /// { 102 /// ^entry(<function-arguments>): 103 /// %token = <async token> : !async.token // create async runtime token 104 /// %value = <async value> : !async.value<T> // create async value 105 /// %id = async.coro.id // create a coroutine id 106 /// %hdl = async.coro.begin %id // create a coroutine handle 107 /// br ^preexisting_entry_block 108 /// 109 /// /* preexisting blocks modified to branch to the cleanup block */ 110 /// 111 /// ^set_error: // this block created lazily only if needed (see code below) 112 /// async.runtime.set_error %token : !async.token 113 /// async.runtime.set_error %value : !async.value<T> 114 /// br ^cleanup 115 /// 116 /// ^cleanup: 117 /// async.coro.free %hdl // delete the coroutine state 118 /// br ^suspend 119 /// 120 /// ^suspend: 121 /// async.coro.end %hdl // marks the end of a coroutine 122 /// return %token, %value : !async.token, !async.value<T> 123 /// } 124 /// 125 static CoroMachinery setupCoroMachinery(FuncOp func) { 126 assert(!func.getBlocks().empty() && "Function must have an entry block"); 127 128 MLIRContext *ctx = func.getContext(); 129 Block *entryBlock = &func.getBlocks().front(); 130 Block *originalEntryBlock = 131 entryBlock->splitBlock(entryBlock->getOperations().begin()); 132 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 133 134 // ------------------------------------------------------------------------ // 135 // Allocate async token/values that we will return from a ramp 
function. 136 // ------------------------------------------------------------------------ // 137 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 138 139 llvm::SmallVector<Value, 4> retValues; 140 for (auto resType : func.getCallableResults().drop_front()) 141 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 142 143 // ------------------------------------------------------------------------ // 144 // Initialize coroutine: get coroutine id and coroutine handle. 145 // ------------------------------------------------------------------------ // 146 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 147 auto coroHdlOp = 148 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 149 builder.create<BranchOp>(originalEntryBlock); 150 151 Block *cleanupBlock = func.addBlock(); 152 Block *suspendBlock = func.addBlock(); 153 154 // ------------------------------------------------------------------------ // 155 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 156 // ------------------------------------------------------------------------ // 157 builder.setInsertionPointToStart(cleanupBlock); 158 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 159 160 // Branch into the suspend block. 161 builder.create<BranchOp>(suspendBlock); 162 163 // ------------------------------------------------------------------------ // 164 // Coroutine suspend block: mark the end of a coroutine and return allocated 165 // async token. 166 // ------------------------------------------------------------------------ // 167 builder.setInsertionPointToStart(suspendBlock); 168 169 // Mark the end of a coroutine: async.coro.end 170 builder.create<CoroEndOp>(coroHdlOp.handle()); 171 172 // Return created `async.token` and `async.values` from the suspend block. 173 // This will be the return value of a coroutine ramp function. 
174 SmallVector<Value, 4> ret{retToken}; 175 ret.insert(ret.end(), retValues.begin(), retValues.end()); 176 builder.create<ReturnOp>(ret); 177 178 // `async.await` op lowering will create resume blocks for async 179 // continuations, and will conditionally branch to cleanup or suspend blocks. 180 181 for (Block &block : func.body().getBlocks()) { 182 if (&block == entryBlock || &block == cleanupBlock || 183 &block == suspendBlock) 184 continue; 185 Operation *terminator = block.getTerminator(); 186 if (auto yield = dyn_cast<YieldOp>(terminator)) { 187 builder.setInsertionPointToEnd(&block); 188 builder.create<BranchOp>(cleanupBlock); 189 } 190 } 191 192 CoroMachinery machinery; 193 machinery.func = func; 194 machinery.asyncToken = retToken; 195 machinery.returnValues = retValues; 196 machinery.coroHandle = coroHdlOp.handle(); 197 machinery.entry = entryBlock; 198 machinery.setError = nullptr; // created lazily only if needed 199 machinery.cleanup = cleanupBlock; 200 machinery.suspend = suspendBlock; 201 return machinery; 202 } 203 204 // Lazily creates `set_error` block only if it is required for lowering to the 205 // runtime operations (see for example lowering of assert operation). 206 static Block *setupSetErrorBlock(CoroMachinery &coro) { 207 if (coro.setError) 208 return coro.setError; 209 210 coro.setError = coro.func.addBlock(); 211 coro.setError->moveBefore(coro.cleanup); 212 213 auto builder = 214 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 215 216 // Coroutine set_error block: set error on token and all returned values. 217 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 218 for (Value retValue : coro.returnValues) 219 builder.create<RuntimeSetErrorOp>(retValue); 220 221 // Branch into the cleanup block. 222 builder.create<BranchOp>(coro.cleanup); 223 224 return coro.setError; 225 } 226 227 /// Outline the body region attached to the `async.execute` op into a standalone 228 /// function. 
229 /// 230 /// Note that this is not reversible transformation. 231 static std::pair<FuncOp, CoroMachinery> 232 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 233 ModuleOp module = execute->getParentOfType<ModuleOp>(); 234 235 MLIRContext *ctx = module.getContext(); 236 Location loc = execute.getLoc(); 237 238 // Make sure that all constants will be inside the outlined async function to 239 // reduce the number of function arguments. 240 cloneConstantsIntoTheRegion(execute.body()); 241 242 // Collect all outlined function inputs. 243 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 244 execute.dependencies().end()); 245 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 246 getUsedValuesDefinedAbove(execute.body(), functionInputs); 247 248 // Collect types for the outlined function inputs and outputs. 249 auto typesRange = llvm::map_range( 250 functionInputs, [](Value value) { return value.getType(); }); 251 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 252 auto outputTypes = execute.getResultTypes(); 253 254 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 255 auto funcAttrs = ArrayRef<NamedAttribute>(); 256 257 // TODO: Derive outlined function name from the parent FuncOp (support 258 // multiple nested async.execute operations). 259 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 260 symbolTable.insert(func); 261 262 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 263 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 264 265 // Prepare for coroutine conversion by creating the body of the function. 266 { 267 size_t numDependencies = execute.dependencies().size(); 268 size_t numOperands = execute.operands().size(); 269 270 // Await on all dependencies before starting to execute the body region. 
271 for (size_t i = 0; i < numDependencies; ++i) 272 builder.create<AwaitOp>(func.getArgument(i)); 273 274 // Await on all async value operands and unwrap the payload. 275 SmallVector<Value, 4> unwrappedOperands(numOperands); 276 for (size_t i = 0; i < numOperands; ++i) { 277 Value operand = func.getArgument(numDependencies + i); 278 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 279 } 280 281 // Map from function inputs defined above the execute op to the function 282 // arguments. 283 BlockAndValueMapping valueMapping; 284 valueMapping.map(functionInputs, func.getArguments()); 285 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 286 287 // Clone all operations from the execute operation body into the outlined 288 // function body. 289 for (Operation &op : execute.body().getOps()) 290 builder.clone(op, valueMapping); 291 } 292 293 // Adding entry/cleanup/suspend blocks. 294 CoroMachinery coro = setupCoroMachinery(func); 295 296 // Suspend async function at the end of an entry block, and resume it using 297 // Async resume operation (execution will be resumed in a thread managed by 298 // the async runtime). 299 { 300 BranchOp branch = cast<BranchOp>(coro.entry->getTerminator()); 301 builder.setInsertionPointToEnd(coro.entry); 302 303 // Save the coroutine state: async.coro.save 304 auto coroSaveOp = 305 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 306 307 // Pass coroutine to the runtime to be resumed on a runtime managed 308 // thread. 309 builder.create<RuntimeResumeOp>(coro.coroHandle); 310 311 // Add async.coro.suspend as a suspended block terminator. 312 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 313 branch.getDest(), coro.cleanup); 314 315 branch.erase(); 316 } 317 318 // Replace the original `async.execute` with a call to outlined function. 
319 { 320 ImplicitLocOpBuilder callBuilder(loc, execute); 321 auto callOutlinedFunc = callBuilder.create<CallOp>( 322 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 323 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 324 execute.erase(); 325 } 326 327 return {func, coro}; 328 } 329 330 //===----------------------------------------------------------------------===// 331 // Convert async.create_group operation to async.runtime.create_group 332 //===----------------------------------------------------------------------===// 333 334 namespace { 335 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 336 public: 337 using OpConversionPattern::OpConversionPattern; 338 339 LogicalResult 340 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 341 ConversionPatternRewriter &rewriter) const override { 342 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 343 op, GroupType::get(op->getContext()), operands); 344 return success(); 345 } 346 }; 347 } // namespace 348 349 //===----------------------------------------------------------------------===// 350 // Convert async.add_to_group operation to async.runtime.add_to_group. 351 //===----------------------------------------------------------------------===// 352 353 namespace { 354 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 355 public: 356 using OpConversionPattern::OpConversionPattern; 357 358 LogicalResult 359 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 360 ConversionPatternRewriter &rewriter) const override { 361 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 362 op, rewriter.getIndexType(), operands); 363 return success(); 364 } 365 }; 366 } // namespace 367 368 //===----------------------------------------------------------------------===// 369 // Convert async.await and async.await_all operations to the async.runtime.await 370 // or async.runtime.await_and_resume operations. 
371 //===----------------------------------------------------------------------===// 372 373 namespace { 374 template <typename AwaitType, typename AwaitableType> 375 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 376 using AwaitAdaptor = typename AwaitType::Adaptor; 377 378 public: 379 AwaitOpLoweringBase(MLIRContext *ctx, 380 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 381 : OpConversionPattern<AwaitType>(ctx), 382 outlinedFunctions(outlinedFunctions) {} 383 384 LogicalResult 385 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 386 ConversionPatternRewriter &rewriter) const override { 387 // We can only await on one the `AwaitableType` (for `await` it can be 388 // a `token` or a `value`, for `await_all` it must be a `group`). 389 if (!op.operand().getType().template isa<AwaitableType>()) 390 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 391 392 // Check if await operation is inside the outlined coroutine function. 393 auto func = op->template getParentOfType<FuncOp>(); 394 auto outlined = outlinedFunctions.find(func); 395 const bool isInCoroutine = outlined != outlinedFunctions.end(); 396 397 Location loc = op->getLoc(); 398 Value operand = AwaitAdaptor(operands).operand(); 399 400 // Inside regular functions we use the blocking wait operation to wait for 401 // the async object (token, value or group) to become available. 402 if (!isInCoroutine) 403 rewriter.create<RuntimeAwaitOp>(loc, operand); 404 405 // Inside the coroutine we convert await operation into coroutine suspension 406 // point, and resume execution asynchronously. 407 if (isInCoroutine) { 408 CoroMachinery &coro = outlined->getSecond(); 409 Block *suspended = op->getBlock(); 410 411 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 412 MLIRContext *ctx = op->getContext(); 413 414 // Save the coroutine state and resume on a runtime managed thread when 415 // the operand becomes available. 
416 auto coroSaveOp = 417 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 418 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 419 420 // Split the entry block before the await operation. 421 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 422 423 // Add async.coro.suspend as a suspended block terminator. 424 builder.setInsertionPointToEnd(suspended); 425 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 426 coro.cleanup); 427 428 // Split the resume block into error checking and continuation. 429 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 430 431 // Check if the awaited value is in the error state. 432 builder.setInsertionPointToStart(resume); 433 auto isError = 434 builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand); 435 builder.create<CondBranchOp>(isError, 436 /*trueDest=*/setupSetErrorBlock(coro), 437 /*trueArgs=*/ArrayRef<Value>(), 438 /*falseDest=*/continuation, 439 /*falseArgs=*/ArrayRef<Value>()); 440 441 // Make sure that replacement value will be constructed in the 442 // continuation block. 443 rewriter.setInsertionPointToStart(continuation); 444 } 445 446 // Erase or replace the await operation with the new value. 447 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 448 rewriter.replaceOp(op, replaceWith); 449 else 450 rewriter.eraseOp(op); 451 452 return success(); 453 } 454 455 virtual Value getReplacementValue(AwaitType op, Value operand, 456 ConversionPatternRewriter &rewriter) const { 457 return Value(); 458 } 459 460 private: 461 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 462 }; 463 464 /// Lowering for `async.await` with a token operand. 465 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 466 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 467 468 public: 469 using Base::Base; 470 }; 471 472 /// Lowering for `async.await` with a value operand. 
473 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 474 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 475 476 public: 477 using Base::Base; 478 479 Value 480 getReplacementValue(AwaitOp op, Value operand, 481 ConversionPatternRewriter &rewriter) const override { 482 // Load from the async value storage. 483 auto valueType = operand.getType().cast<ValueType>().getValueType(); 484 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 485 } 486 }; 487 488 /// Lowering for `async.await_all` operation. 489 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 490 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 491 492 public: 493 using Base::Base; 494 }; 495 496 } // namespace 497 498 //===----------------------------------------------------------------------===// 499 // Convert async.yield operation to async.runtime operations. 500 //===----------------------------------------------------------------------===// 501 502 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 503 public: 504 YieldOpLowering( 505 MLIRContext *ctx, 506 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 507 : OpConversionPattern<async::YieldOp>(ctx), 508 outlinedFunctions(outlinedFunctions) {} 509 510 LogicalResult 511 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 512 ConversionPatternRewriter &rewriter) const override { 513 // Check if yield operation is inside the async coroutine function. 514 auto func = op->template getParentOfType<FuncOp>(); 515 auto outlined = outlinedFunctions.find(func); 516 if (outlined == outlinedFunctions.end()) 517 return rewriter.notifyMatchFailure( 518 op, "operation is not inside the async coroutine function"); 519 520 Location loc = op->getLoc(); 521 const CoroMachinery &coro = outlined->getSecond(); 522 523 // Store yielded values into the async values storage and switch async 524 // values state to available. 
525 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 526 Value yieldValue = std::get<0>(tuple); 527 Value asyncValue = std::get<1>(tuple); 528 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 529 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 530 } 531 532 // Switch the coroutine completion token to available state. 533 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 534 535 return success(); 536 } 537 538 private: 539 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 540 }; 541 542 //===----------------------------------------------------------------------===// 543 // Convert std.assert operation to cond_br into `set_error` block. 544 //===----------------------------------------------------------------------===// 545 546 class AssertOpLowering : public OpConversionPattern<AssertOp> { 547 public: 548 AssertOpLowering(MLIRContext *ctx, 549 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 550 : OpConversionPattern<AssertOp>(ctx), 551 outlinedFunctions(outlinedFunctions) {} 552 553 LogicalResult 554 matchAndRewrite(AssertOp op, ArrayRef<Value> operands, 555 ConversionPatternRewriter &rewriter) const override { 556 // Check if assert operation is inside the async coroutine function. 
557 auto func = op->template getParentOfType<FuncOp>(); 558 auto outlined = outlinedFunctions.find(func); 559 if (outlined == outlinedFunctions.end()) 560 return rewriter.notifyMatchFailure( 561 op, "operation is not inside the async coroutine function"); 562 563 Location loc = op->getLoc(); 564 CoroMachinery &coro = outlined->getSecond(); 565 566 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 567 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 568 rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(), 569 /*trueDest=*/cont, 570 /*trueArgs=*/ArrayRef<Value>(), 571 /*falseDest=*/setupSetErrorBlock(coro), 572 /*falseArgs=*/ArrayRef<Value>()); 573 rewriter.eraseOp(op); 574 575 return success(); 576 } 577 578 private: 579 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 580 }; 581 582 //===----------------------------------------------------------------------===// 583 584 /// Rewrite a func as a coroutine by: 585 /// 1) Wrapping the results into `async.value`. 586 /// 2) Prepending the results with `async.token`. 587 /// 3) Setting up coroutine blocks. 588 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 
589 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) { 590 auto *ctx = func->getContext(); 591 auto loc = func.getLoc(); 592 SmallVector<Type> resultTypes; 593 resultTypes.reserve(func.getCallableResults().size()); 594 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 595 [](Type type) { return ValueType::get(type); }); 596 func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes)); 597 func.insertResult(0, TokenType::get(ctx), {}); 598 for (Block &block : func.getBlocks()) { 599 Operation *terminator = block.getTerminator(); 600 if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) { 601 ImplicitLocOpBuilder builder(loc, returnOp); 602 builder.create<YieldOp>(returnOp.getOperands()); 603 returnOp.erase(); 604 } 605 } 606 return setupCoroMachinery(func); 607 } 608 609 /// Rewrites a call into a function that has been rewritten as a coroutine. 610 /// 611 /// The invocation of this function is safe only when call ops are traversed in 612 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 613 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) { 614 auto loc = func.getLoc(); 615 ImplicitLocOpBuilder callBuilder(loc, oldCall); 616 auto newCall = callBuilder.create<CallOp>( 617 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 618 619 // Await on the async token and all the value results and unwrap the latter. 620 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 621 SmallVector<Value> unwrappedResults; 622 unwrappedResults.reserve(newCall->getResults().size() - 1); 623 for (Value result : newCall.getResults().drop_front()) 624 unwrappedResults.push_back( 625 callBuilder.create<AwaitOp>(loc, result).result()); 626 // Careful, when result of a call is piped into another call this could lead 627 // to a dangling pointer. 
628 oldCall.replaceAllUsesWith(unwrappedResults); 629 oldCall.erase(); 630 } 631 632 static bool isAllowedToBlock(FuncOp func) { 633 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 634 } 635 636 static LogicalResult 637 funcsToCoroutines(ModuleOp module, 638 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) { 639 // The following code supports the general case when 2 functions mutually 640 // recurse into each other. Because of this and that we are relying on 641 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 642 // a FuncOp while inserting an equivalent coroutine, because that could lead 643 // to dangling pointers. 644 645 SmallVector<FuncOp> funcWorklist; 646 647 // Careful, it's okay to add a func to the worklist multiple times if and only 648 // if the loop processing the worklist will skip the functions that have 649 // already been converted to coroutines. 650 auto addToWorklist = [&](FuncOp func) { 651 if (isAllowedToBlock(func)) 652 return; 653 // N.B. To refactor this code into a separate pass the lookup in 654 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 655 // func and recognizing if it has a coroutine structure is messy. Passing 656 // this dict between the passes is ugly. 657 if (isAllowedToBlock(func) || 658 outlinedFunctions.find(func) == outlinedFunctions.end()) { 659 for (Operation &op : func.body().getOps()) { 660 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 661 funcWorklist.push_back(func); 662 break; 663 } 664 } 665 } 666 }; 667 668 // Traverse in post-order collecting for each func op the await ops it has. 669 for (FuncOp func : module.getOps<FuncOp>()) 670 addToWorklist(func); 671 672 SymbolTableCollection symbolTable; 673 SymbolUserMap symbolUserMap(symbolTable, module); 674 675 // Rewrite funcs, while updating call sites and adding them to the worklist. 
676 while (!funcWorklist.empty()) { 677 auto func = funcWorklist.pop_back_val(); 678 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 679 if (!insertion.second) 680 // This function has already been processed because this is either 681 // the corecursive case, or a caller with multiple calls to a newly 682 // created corouting. Either way, skip updating the call sites. 683 continue; 684 insertion.first->second = rewriteFuncAsCoroutine(func); 685 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 686 symbolUserMap.getUsers(func).end()); 687 // If there are multiple calls from the same block they need to be traversed 688 // in reverse order so that symbolUserMap references are not invalidated 689 // when updating the users of the call op which is earlier in the block. 690 llvm::sort(users, [](Operation *a, Operation *b) { 691 Block *blockA = a->getBlock(); 692 Block *blockB = b->getBlock(); 693 // Impose arbitrary order on blocks so that there is a well-defined order. 694 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 695 }); 696 // Rewrite the callsites to await on results of the newly created coroutine. 697 for (Operation *op : users) { 698 if (CallOp call = dyn_cast<mlir::CallOp>(*op)) { 699 FuncOp caller = call->getParentOfType<FuncOp>(); 700 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 701 addToWorklist(caller); 702 } else { 703 op->emitError("Unexpected reference to func referenced by symbol"); 704 return failure(); 705 } 706 } 707 } 708 return success(); 709 } 710 711 //===----------------------------------------------------------------------===// 712 void AsyncToAsyncRuntimePass::runOnOperation() { 713 ModuleOp module = getOperation(); 714 SymbolTable symbolTable(module); 715 716 // Outline all `async.execute` body regions into async functions (coroutines). 
717 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 718 719 module.walk([&](ExecuteOp execute) { 720 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 721 }); 722 723 LLVM_DEBUG({ 724 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 725 << " functions built from async.execute operations\n"; 726 }); 727 728 // Returns true if operation is inside the coroutine. 729 auto isInCoroutine = [&](Operation *op) -> bool { 730 auto parentFunc = op->getParentOfType<FuncOp>(); 731 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 732 }; 733 734 if (eliminateBlockingAwaitOps && 735 failed(funcsToCoroutines(module, outlinedFunctions))) { 736 signalPassFailure(); 737 return; 738 } 739 740 // Lower async operations to async.runtime operations. 741 MLIRContext *ctx = module->getContext(); 742 RewritePatternSet asyncPatterns(ctx); 743 744 // Conversion to async runtime augments original CFG with the coroutine CFG, 745 // and we have to make sure that structured control flow operations with async 746 // operations in nested regions will be converted to branch-based control flow 747 // before we add the coroutine basic blocks. 748 populateLoopToStdConversionPatterns(asyncPatterns); 749 750 // Async lowering does not use type converter because it must preserve all 751 // types for async.runtime operations. 752 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 753 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 754 AwaitAllOpLowering, YieldOpLowering>(ctx, 755 outlinedFunctions); 756 757 // Lower assertions to conditional branches into error blocks. 758 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 759 760 // All high level async operations must be lowered to the runtime operations. 
761 ConversionTarget runtimeTarget(*ctx); 762 runtimeTarget.addLegalDialect<AsyncDialect>(); 763 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 764 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 765 766 // Decide if structured control flow has to be lowered to branch-based CFG. 767 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 768 auto walkResult = op->walk([&](Operation *nested) { 769 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 770 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 771 : WalkResult::advance(); 772 }); 773 return !walkResult.wasInterrupted(); 774 }); 775 runtimeTarget.addLegalOp<BranchOp, CondBranchOp>(); 776 777 // Assertions must be converted to runtime errors inside async functions. 778 runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 779 auto func = op->getParentOfType<FuncOp>(); 780 return outlinedFunctions.find(func) == outlinedFunctions.end(); 781 }); 782 783 if (eliminateBlockingAwaitOps) 784 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 785 [&](RuntimeAwaitOp op) -> bool { 786 return isAllowedToBlock(op->getParentOfType<FuncOp>()); 787 }); 788 789 if (failed(applyPartialConversion(module, runtimeTarget, 790 std::move(asyncPatterns)))) { 791 signalPassFailure(); 792 return; 793 } 794 } 795 796 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 797 return std::make_unique<AsyncToAsyncRuntimePass>(); 798 } 799