1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/Async/Passes.h" 18 #include "mlir/Dialect/SCF/SCF.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/IR/BlockAndValueMapping.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "mlir/Transforms/RegionUtils.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/Support/Debug.h" 27 28 using namespace mlir; 29 using namespace mlir::async; 30 31 #define DEBUG_TYPE "async-to-async-runtime" 32 // Prefix for functions outlined from `async.execute` op regions. 33 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 34 35 namespace { 36 37 class AsyncToAsyncRuntimePass 38 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 39 public: 40 AsyncToAsyncRuntimePass() = default; 41 void runOnOperation() override; 42 }; 43 44 } // namespace 45 46 //===----------------------------------------------------------------------===// 47 // async.execute op outlining to the coroutine functions. 48 //===----------------------------------------------------------------------===// 49 50 /// Function targeted for coroutine transformation has two additional blocks at 51 /// the end: coroutine cleanup and coroutine suspension. 52 /// 53 /// async.await op lowering additionaly creates a resume block for each 54 /// operation to enable non-blocking waiting via coroutine suspension. 55 namespace { 56 struct CoroMachinery { 57 FuncOp func; 58 59 // Async execute region returns a completion token, and an async value for 60 // each yielded value. 61 // 62 // %token, %result = async.execute -> !async.value<T> { 63 // %0 = constant ... : T 64 // async.yield %0 : T 65 // } 66 Value asyncToken; // token representing completion of the async region 67 llvm::SmallVector<Value, 4> returnValues; // returned async values 68 69 Value coroHandle; // coroutine handle (!async.coro.handle value) 70 Block *entry; // coroutine entry block 71 Block *setError; // switch completion token and all values to error state 72 Block *cleanup; // coroutine cleanup block 73 Block *suspend; // coroutine suspension block 74 }; 75 } // namespace 76 77 /// Utility to partially update the regular function CFG to the coroutine CFG 78 /// compatible with LLVM coroutines switched-resume lowering using 79 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 80 /// that branches into preexisting entry block. Also inserts trailing blocks. 81 /// 82 /// The result types of the passed `func` must start with an `async.token` 83 /// and be continued with some number of `async.value`s. 84 /// 85 /// The func given to this function needs to have been preprocessed to have 86 /// either branch or yield ops as terminators. Branches to the cleanup block are 87 /// inserted after each yield. 88 /// 89 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 90 /// 91 /// - `entry` block sets up the coroutine. 92 /// - `set_error` block sets completion token and async values state to error. 93 /// - `cleanup` block cleans up the coroutine state. 94 /// - `suspend block after the @llvm.coro.end() defines what value will be 95 /// returned to the initial caller of a coroutine. Everything before the 96 /// @llvm.coro.end() will be executed at every suspension point. 97 /// 98 /// Coroutine structure (only the important bits): 99 /// 100 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 101 /// { 102 /// ^entry(<function-arguments>): 103 /// %token = <async token> : !async.token // create async runtime token 104 /// %value = <async value> : !async.value<T> // create async value 105 /// %id = async.coro.id // create a coroutine id 106 /// %hdl = async.coro.begin %id // create a coroutine handle 107 /// br ^preexisting_entry_block 108 /// 109 /// /* preexisting blocks modified to branch to the cleanup block */ 110 /// 111 /// ^set_error: // this block created lazily only if needed (see code below) 112 /// async.runtime.set_error %token : !async.token 113 /// async.runtime.set_error %value : !async.value<T> 114 /// br ^cleanup 115 /// 116 /// ^cleanup: 117 /// async.coro.free %hdl // delete the coroutine state 118 /// br ^suspend 119 /// 120 /// ^suspend: 121 /// async.coro.end %hdl // marks the end of a coroutine 122 /// return %token, %value : !async.token, !async.value<T> 123 /// } 124 /// 125 static CoroMachinery setupCoroMachinery(FuncOp func) { 126 assert(!func.getBlocks().empty() && "Function must have an entry block"); 127 128 MLIRContext *ctx = func.getContext(); 129 Block *entryBlock = &func.getBlocks().front(); 130 Block *originalEntryBlock = 131 entryBlock->splitBlock(entryBlock->getOperations().begin()); 132 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 133 134 // ------------------------------------------------------------------------ // 135 // Allocate async token/values that we will return from a ramp function. 136 // ------------------------------------------------------------------------ // 137 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 138 139 llvm::SmallVector<Value, 4> retValues; 140 for (auto resType : func.getCallableResults().drop_front()) 141 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 142 143 // ------------------------------------------------------------------------ // 144 // Initialize coroutine: get coroutine id and coroutine handle. 145 // ------------------------------------------------------------------------ // 146 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 147 auto coroHdlOp = 148 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 149 builder.create<BranchOp>(originalEntryBlock); 150 151 Block *cleanupBlock = func.addBlock(); 152 Block *suspendBlock = func.addBlock(); 153 154 // ------------------------------------------------------------------------ // 155 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 156 // ------------------------------------------------------------------------ // 157 builder.setInsertionPointToStart(cleanupBlock); 158 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 159 160 // Branch into the suspend block. 161 builder.create<BranchOp>(suspendBlock); 162 163 // ------------------------------------------------------------------------ // 164 // Coroutine suspend block: mark the end of a coroutine and return allocated 165 // async token. 166 // ------------------------------------------------------------------------ // 167 builder.setInsertionPointToStart(suspendBlock); 168 169 // Mark the end of a coroutine: async.coro.end 170 builder.create<CoroEndOp>(coroHdlOp.handle()); 171 172 // Return created `async.token` and `async.values` from the suspend block. 173 // This will be the return value of a coroutine ramp function. 174 SmallVector<Value, 4> ret{retToken}; 175 ret.insert(ret.end(), retValues.begin(), retValues.end()); 176 builder.create<ReturnOp>(ret); 177 178 // `async.await` op lowering will create resume blocks for async 179 // continuations, and will conditionally branch to cleanup or suspend blocks. 180 181 for (Block &block : func.body().getBlocks()) { 182 if (&block == entryBlock || &block == cleanupBlock || 183 &block == suspendBlock) 184 continue; 185 Operation *terminator = block.getTerminator(); 186 if (auto yield = dyn_cast<YieldOp>(terminator)) { 187 builder.setInsertionPointToEnd(&block); 188 builder.create<BranchOp>(cleanupBlock); 189 } 190 } 191 192 CoroMachinery machinery; 193 machinery.func = func; 194 machinery.asyncToken = retToken; 195 machinery.returnValues = retValues; 196 machinery.coroHandle = coroHdlOp.handle(); 197 machinery.entry = entryBlock; 198 machinery.setError = nullptr; // created lazily only if needed 199 machinery.cleanup = cleanupBlock; 200 machinery.suspend = suspendBlock; 201 return machinery; 202 } 203 204 // Lazily creates `set_error` block only if it is required for lowering to the 205 // runtime operations (see for example lowering of assert operation). 206 static Block *setupSetErrorBlock(CoroMachinery &coro) { 207 if (coro.setError) 208 return coro.setError; 209 210 coro.setError = coro.func.addBlock(); 211 coro.setError->moveBefore(coro.cleanup); 212 213 auto builder = 214 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 215 216 // Coroutine set_error block: set error on token and all returned values. 217 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 218 for (Value retValue : coro.returnValues) 219 builder.create<RuntimeSetErrorOp>(retValue); 220 221 // Branch into the cleanup block. 222 builder.create<BranchOp>(coro.cleanup); 223 224 return coro.setError; 225 } 226 227 /// Outline the body region attached to the `async.execute` op into a standalone 228 /// function. 229 /// 230 /// Note that this is not reversible transformation. 231 static std::pair<FuncOp, CoroMachinery> 232 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 233 ModuleOp module = execute->getParentOfType<ModuleOp>(); 234 235 MLIRContext *ctx = module.getContext(); 236 Location loc = execute.getLoc(); 237 238 // Collect all outlined function inputs. 239 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 240 execute.dependencies().end()); 241 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 242 getUsedValuesDefinedAbove(execute.body(), functionInputs); 243 244 // Collect types for the outlined function inputs and outputs. 245 auto typesRange = llvm::map_range( 246 functionInputs, [](Value value) { return value.getType(); }); 247 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 248 auto outputTypes = execute.getResultTypes(); 249 250 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 251 auto funcAttrs = ArrayRef<NamedAttribute>(); 252 253 // TODO: Derive outlined function name from the parent FuncOp (support 254 // multiple nested async.execute operations). 255 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 256 symbolTable.insert(func); 257 258 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 259 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 260 261 // Prepare for coroutine conversion by creating the body of the function. 262 { 263 size_t numDependencies = execute.dependencies().size(); 264 size_t numOperands = execute.operands().size(); 265 266 // Await on all dependencies before starting to execute the body region. 267 for (size_t i = 0; i < numDependencies; ++i) 268 builder.create<AwaitOp>(func.getArgument(i)); 269 270 // Await on all async value operands and unwrap the payload. 271 SmallVector<Value, 4> unwrappedOperands(numOperands); 272 for (size_t i = 0; i < numOperands; ++i) { 273 Value operand = func.getArgument(numDependencies + i); 274 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 275 } 276 277 // Map from function inputs defined above the execute op to the function 278 // arguments. 279 BlockAndValueMapping valueMapping; 280 valueMapping.map(functionInputs, func.getArguments()); 281 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 282 283 // Clone all operations from the execute operation body into the outlined 284 // function body. 285 for (Operation &op : execute.body().getOps()) 286 builder.clone(op, valueMapping); 287 } 288 289 // Adding entry/cleanup/suspend blocks. 290 CoroMachinery coro = setupCoroMachinery(func); 291 292 // Suspend async function at the end of an entry block, and resume it using 293 // Async resume operation (execution will be resumed in a thread managed by 294 // the async runtime). 295 { 296 BranchOp branch = cast<BranchOp>(coro.entry->getTerminator()); 297 builder.setInsertionPointToEnd(coro.entry); 298 299 // Save the coroutine state: async.coro.save 300 auto coroSaveOp = 301 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 302 303 // Pass coroutine to the runtime to be resumed on a runtime managed 304 // thread. 305 builder.create<RuntimeResumeOp>(coro.coroHandle); 306 307 // Add async.coro.suspend as a suspended block terminator. 308 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 309 branch.getDest(), coro.cleanup); 310 311 branch.erase(); 312 } 313 314 // Replace the original `async.execute` with a call to outlined function. 315 { 316 ImplicitLocOpBuilder callBuilder(loc, execute); 317 auto callOutlinedFunc = callBuilder.create<CallOp>( 318 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 319 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 320 execute.erase(); 321 } 322 323 return {func, coro}; 324 } 325 326 //===----------------------------------------------------------------------===// 327 // Convert async.create_group operation to async.runtime.create_group 328 //===----------------------------------------------------------------------===// 329 330 namespace { 331 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 332 public: 333 using OpConversionPattern::OpConversionPattern; 334 335 LogicalResult 336 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 337 ConversionPatternRewriter &rewriter) const override { 338 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 339 op, GroupType::get(op->getContext()), operands); 340 return success(); 341 } 342 }; 343 } // namespace 344 345 //===----------------------------------------------------------------------===// 346 // Convert async.add_to_group operation to async.runtime.add_to_group. 347 //===----------------------------------------------------------------------===// 348 349 namespace { 350 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 351 public: 352 using OpConversionPattern::OpConversionPattern; 353 354 LogicalResult 355 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 356 ConversionPatternRewriter &rewriter) const override { 357 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 358 op, rewriter.getIndexType(), operands); 359 return success(); 360 } 361 }; 362 } // namespace 363 364 //===----------------------------------------------------------------------===// 365 // Convert async.await and async.await_all operations to the async.runtime.await 366 // or async.runtime.await_and_resume operations. 367 //===----------------------------------------------------------------------===// 368 369 namespace { 370 template <typename AwaitType, typename AwaitableType> 371 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 372 using AwaitAdaptor = typename AwaitType::Adaptor; 373 374 public: 375 AwaitOpLoweringBase(MLIRContext *ctx, 376 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 377 : OpConversionPattern<AwaitType>(ctx), 378 outlinedFunctions(outlinedFunctions) {} 379 380 LogicalResult 381 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 382 ConversionPatternRewriter &rewriter) const override { 383 // We can only await on one the `AwaitableType` (for `await` it can be 384 // a `token` or a `value`, for `await_all` it must be a `group`). 385 if (!op.operand().getType().template isa<AwaitableType>()) 386 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 387 388 // Check if await operation is inside the outlined coroutine function. 389 auto func = op->template getParentOfType<FuncOp>(); 390 auto outlined = outlinedFunctions.find(func); 391 const bool isInCoroutine = outlined != outlinedFunctions.end(); 392 393 Location loc = op->getLoc(); 394 Value operand = AwaitAdaptor(operands).operand(); 395 396 // Inside regular functions we use the blocking wait operation to wait for 397 // the async object (token, value or group) to become available. 398 if (!isInCoroutine) 399 rewriter.create<RuntimeAwaitOp>(loc, operand); 400 401 // Inside the coroutine we convert await operation into coroutine suspension 402 // point, and resume execution asynchronously. 403 if (isInCoroutine) { 404 CoroMachinery &coro = outlined->getSecond(); 405 Block *suspended = op->getBlock(); 406 407 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 408 MLIRContext *ctx = op->getContext(); 409 410 // Save the coroutine state and resume on a runtime managed thread when 411 // the operand becomes available. 412 auto coroSaveOp = 413 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 414 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 415 416 // Split the entry block before the await operation. 417 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 418 419 // Add async.coro.suspend as a suspended block terminator. 420 builder.setInsertionPointToEnd(suspended); 421 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 422 coro.cleanup); 423 424 // Split the resume block into error checking and continuation. 425 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 426 427 // Check if the awaited value is in the error state. 428 builder.setInsertionPointToStart(resume); 429 auto isError = 430 builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand); 431 builder.create<CondBranchOp>(isError, 432 /*trueDest=*/setupSetErrorBlock(coro), 433 /*trueArgs=*/ArrayRef<Value>(), 434 /*falseDest=*/continuation, 435 /*falseArgs=*/ArrayRef<Value>()); 436 437 // Make sure that replacement value will be constructed in the 438 // continuation block. 439 rewriter.setInsertionPointToStart(continuation); 440 } 441 442 // Erase or replace the await operation with the new value. 443 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 444 rewriter.replaceOp(op, replaceWith); 445 else 446 rewriter.eraseOp(op); 447 448 return success(); 449 } 450 451 virtual Value getReplacementValue(AwaitType op, Value operand, 452 ConversionPatternRewriter &rewriter) const { 453 return Value(); 454 } 455 456 private: 457 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 458 }; 459 460 /// Lowering for `async.await` with a token operand. 461 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 462 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 463 464 public: 465 using Base::Base; 466 }; 467 468 /// Lowering for `async.await` with a value operand. 469 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 470 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 471 472 public: 473 using Base::Base; 474 475 Value 476 getReplacementValue(AwaitOp op, Value operand, 477 ConversionPatternRewriter &rewriter) const override { 478 // Load from the async value storage. 479 auto valueType = operand.getType().cast<ValueType>().getValueType(); 480 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 481 } 482 }; 483 484 /// Lowering for `async.await_all` operation. 485 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 486 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 487 488 public: 489 using Base::Base; 490 }; 491 492 } // namespace 493 494 //===----------------------------------------------------------------------===// 495 // Convert async.yield operation to async.runtime operations. 496 //===----------------------------------------------------------------------===// 497 498 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 499 public: 500 YieldOpLowering( 501 MLIRContext *ctx, 502 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 503 : OpConversionPattern<async::YieldOp>(ctx), 504 outlinedFunctions(outlinedFunctions) {} 505 506 LogicalResult 507 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 508 ConversionPatternRewriter &rewriter) const override { 509 // Check if yield operation is inside the async coroutine function. 510 auto func = op->template getParentOfType<FuncOp>(); 511 auto outlined = outlinedFunctions.find(func); 512 if (outlined == outlinedFunctions.end()) 513 return rewriter.notifyMatchFailure( 514 op, "operation is not inside the async coroutine function"); 515 516 Location loc = op->getLoc(); 517 const CoroMachinery &coro = outlined->getSecond(); 518 519 // Store yielded values into the async values storage and switch async 520 // values state to available. 521 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 522 Value yieldValue = std::get<0>(tuple); 523 Value asyncValue = std::get<1>(tuple); 524 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 525 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 526 } 527 528 // Switch the coroutine completion token to available state. 529 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 530 531 return success(); 532 } 533 534 private: 535 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 536 }; 537 538 //===----------------------------------------------------------------------===// 539 // Convert std.assert operation to cond_br into `set_error` block. 540 //===----------------------------------------------------------------------===// 541 542 class AssertOpLowering : public OpConversionPattern<AssertOp> { 543 public: 544 AssertOpLowering(MLIRContext *ctx, 545 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 546 : OpConversionPattern<AssertOp>(ctx), 547 outlinedFunctions(outlinedFunctions) {} 548 549 LogicalResult 550 matchAndRewrite(AssertOp op, ArrayRef<Value> operands, 551 ConversionPatternRewriter &rewriter) const override { 552 // Check if assert operation is inside the async coroutine function. 553 auto func = op->template getParentOfType<FuncOp>(); 554 auto outlined = outlinedFunctions.find(func); 555 if (outlined == outlinedFunctions.end()) 556 return rewriter.notifyMatchFailure( 557 op, "operation is not inside the async coroutine function"); 558 559 Location loc = op->getLoc(); 560 CoroMachinery &coro = outlined->getSecond(); 561 562 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 563 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 564 rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(), 565 /*trueDest=*/cont, 566 /*trueArgs=*/ArrayRef<Value>(), 567 /*falseDest=*/setupSetErrorBlock(coro), 568 /*falseArgs=*/ArrayRef<Value>()); 569 rewriter.eraseOp(op); 570 571 return success(); 572 } 573 574 private: 575 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 576 }; 577 578 //===----------------------------------------------------------------------===// 579 580 /// Rewrite a func as a coroutine by: 581 /// 1) Wrapping the results into `async.value`. 582 /// 2) Prepending the results with `async.token`. 583 /// 3) Setting up coroutine blocks. 584 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 585 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) { 586 auto *ctx = func->getContext(); 587 auto loc = func.getLoc(); 588 SmallVector<Type> resultTypes; 589 resultTypes.reserve(func.getCallableResults().size()); 590 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 591 [](Type type) { return ValueType::get(type); }); 592 func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes)); 593 func.insertResult(0, TokenType::get(ctx), {}); 594 for (Block &block : func.getBlocks()) { 595 Operation *terminator = block.getTerminator(); 596 if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) { 597 ImplicitLocOpBuilder builder(loc, returnOp); 598 builder.create<YieldOp>(returnOp.getOperands()); 599 returnOp.erase(); 600 } 601 } 602 return setupCoroMachinery(func); 603 } 604 605 /// Rewrites a call into a function that has been rewritten as a coroutine. 606 /// 607 /// The invocation of this function is safe only when call ops are traversed in 608 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 609 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) { 610 auto loc = func.getLoc(); 611 ImplicitLocOpBuilder callBuilder(loc, oldCall); 612 auto newCall = callBuilder.create<CallOp>( 613 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 614 615 // Await on the async token and all the value results and unwrap the latter. 616 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 617 SmallVector<Value> unwrappedResults; 618 unwrappedResults.reserve(newCall->getResults().size() - 1); 619 for (Value result : newCall.getResults().drop_front()) 620 unwrappedResults.push_back( 621 callBuilder.create<AwaitOp>(loc, result).result()); 622 // Careful, when result of a call is piped into another call this could lead 623 // to a dangling pointer. 624 oldCall.replaceAllUsesWith(unwrappedResults); 625 oldCall.erase(); 626 } 627 628 static bool isAllowedToBlock(FuncOp func) { 629 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 630 } 631 632 static LogicalResult 633 funcsToCoroutines(ModuleOp module, 634 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) { 635 // The following code supports the general case when 2 functions mutually 636 // recurse into each other. Because of this and that we are relying on 637 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 638 // a FuncOp while inserting an equivalent coroutine, because that could lead 639 // to dangling pointers. 640 641 SmallVector<FuncOp> funcWorklist; 642 643 // Careful, it's okay to add a func to the worklist multiple times if and only 644 // if the loop processing the worklist will skip the functions that have 645 // already been converted to coroutines. 646 auto addToWorklist = [&](FuncOp func) { 647 if (isAllowedToBlock(func)) 648 return; 649 // N.B. To refactor this code into a separate pass the lookup in 650 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 651 // func and recognizing if it has a coroutine structure is messy. Passing 652 // this dict between the passes is ugly. 653 if (isAllowedToBlock(func) || 654 outlinedFunctions.find(func) == outlinedFunctions.end()) { 655 for (Operation &op : func.body().getOps()) { 656 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 657 funcWorklist.push_back(func); 658 break; 659 } 660 } 661 } 662 }; 663 664 // Traverse in post-order collecting for each func op the await ops it has. 665 for (FuncOp func : module.getOps<FuncOp>()) 666 addToWorklist(func); 667 668 SymbolTableCollection symbolTable; 669 SymbolUserMap symbolUserMap(symbolTable, module); 670 671 // Rewrite funcs, while updating call sites and adding them to the worklist. 672 while (!funcWorklist.empty()) { 673 auto func = funcWorklist.pop_back_val(); 674 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 675 if (!insertion.second) 676 // This function has already been processed because this is either 677 // the corecursive case, or a caller with multiple calls to a newly 678 // created corouting. Either way, skip updating the call sites. 679 continue; 680 insertion.first->second = rewriteFuncAsCoroutine(func); 681 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 682 symbolUserMap.getUsers(func).end()); 683 // If there are multiple calls from the same block they need to be traversed 684 // in reverse order so that symbolUserMap references are not invalidated 685 // when updating the users of the call op which is earlier in the block. 686 llvm::sort(users, [](Operation *a, Operation *b) { 687 Block *blockA = a->getBlock(); 688 Block *blockB = b->getBlock(); 689 // Impose arbitrary order on blocks so that there is a well-defined order. 690 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 691 }); 692 // Rewrite the callsites to await on results of the newly created coroutine. 693 for (Operation *op : users) { 694 if (CallOp call = dyn_cast<mlir::CallOp>(*op)) { 695 FuncOp caller = call->getParentOfType<FuncOp>(); 696 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 697 addToWorklist(caller); 698 } else { 699 op->emitError("Unexpected reference to func referenced by symbol"); 700 return failure(); 701 } 702 } 703 } 704 return success(); 705 } 706 707 //===----------------------------------------------------------------------===// 708 void AsyncToAsyncRuntimePass::runOnOperation() { 709 ModuleOp module = getOperation(); 710 SymbolTable symbolTable(module); 711 712 // Outline all `async.execute` body regions into async functions (coroutines). 713 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 714 715 module.walk([&](ExecuteOp execute) { 716 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 717 }); 718 719 LLVM_DEBUG({ 720 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 721 << " functions built from async.execute operations\n"; 722 }); 723 724 // Returns true if operation is inside the coroutine. 725 auto isInCoroutine = [&](Operation *op) -> bool { 726 auto parentFunc = op->getParentOfType<FuncOp>(); 727 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 728 }; 729 730 if (eliminateBlockingAwaitOps && 731 failed(funcsToCoroutines(module, outlinedFunctions))) { 732 signalPassFailure(); 733 return; 734 } 735 736 // Lower async operations to async.runtime operations. 737 MLIRContext *ctx = module->getContext(); 738 RewritePatternSet asyncPatterns(ctx); 739 740 // Conversion to async runtime augments original CFG with the coroutine CFG, 741 // and we have to make sure that structured control flow operations with async 742 // operations in nested regions will be converted to branch-based control flow 743 // before we add the coroutine basic blocks. 744 populateLoopToStdConversionPatterns(asyncPatterns); 745 746 // Async lowering does not use type converter because it must preserve all 747 // types for async.runtime operations. 748 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 749 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 750 AwaitAllOpLowering, YieldOpLowering>(ctx, 751 outlinedFunctions); 752 753 // Lower assertions to conditional branches into error blocks. 754 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 755 756 // All high level async operations must be lowered to the runtime operations. 757 ConversionTarget runtimeTarget(*ctx); 758 runtimeTarget.addLegalDialect<AsyncDialect>(); 759 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 760 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 761 762 // Decide if structured control flow has to be lowered to branch-based CFG. 763 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 764 auto walkResult = op->walk([&](Operation *nested) { 765 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 766 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 767 : WalkResult::advance(); 768 }); 769 return !walkResult.wasInterrupted(); 770 }); 771 runtimeTarget.addLegalOp<BranchOp, CondBranchOp>(); 772 773 // Assertions must be converted to runtime errors inside async functions. 774 runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 775 auto func = op->getParentOfType<FuncOp>(); 776 return outlinedFunctions.find(func) == outlinedFunctions.end(); 777 }); 778 779 if (eliminateBlockingAwaitOps) 780 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 781 [&](RuntimeAwaitOp op) -> bool { 782 return isAllowedToBlock(op->getParentOfType<FuncOp>()); 783 }); 784 785 if (failed(applyPartialConversion(module, runtimeTarget, 786 std::move(asyncPatterns)))) { 787 signalPassFailure(); 788 return; 789 } 790 } 791 792 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 793 return std::make_unique<AsyncToAsyncRuntimePass>(); 794 } 795