1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 16 #include "mlir/Dialect/Async/IR/Async.h" 17 #include "mlir/Dialect/Async/Passes.h" 18 #include "mlir/Dialect/SCF/SCF.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/IR/BlockAndValueMapping.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "mlir/Transforms/RegionUtils.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/Support/Debug.h" 27 28 using namespace mlir; 29 using namespace mlir::async; 30 31 #define DEBUG_TYPE "async-to-async-runtime" 32 // Prefix for functions outlined from `async.execute` op regions. 33 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 34 35 namespace { 36 37 class AsyncToAsyncRuntimePass 38 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 39 public: 40 AsyncToAsyncRuntimePass() = default; 41 void runOnOperation() override; 42 }; 43 44 } // namespace 45 46 //===----------------------------------------------------------------------===// 47 // async.execute op outlining to the coroutine functions. 48 //===----------------------------------------------------------------------===// 49 50 /// Function targeted for coroutine transformation has two additional blocks at 51 /// the end: coroutine cleanup and coroutine suspension. 52 /// 53 /// async.await op lowering additionaly creates a resume block for each 54 /// operation to enable non-blocking waiting via coroutine suspension. 55 namespace { 56 struct CoroMachinery { 57 FuncOp func; 58 59 // Async execute region returns a completion token, and an async value for 60 // each yielded value. 61 // 62 // %token, %result = async.execute -> !async.value<T> { 63 // %0 = constant ... : T 64 // async.yield %0 : T 65 // } 66 Value asyncToken; // token representing completion of the async region 67 llvm::SmallVector<Value, 4> returnValues; // returned async values 68 69 Value coroHandle; // coroutine handle (!async.coro.handle value) 70 Block *entry; // coroutine entry block 71 Block *setError; // switch completion token and all values to error state 72 Block *cleanup; // coroutine cleanup block 73 Block *suspend; // coroutine suspension block 74 }; 75 } // namespace 76 77 /// Utility to partially update the regular function CFG to the coroutine CFG 78 /// compatible with LLVM coroutines switched-resume lowering using 79 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 80 /// that branches into preexisting entry block. Also inserts trailing blocks. 81 /// 82 /// The result types of the passed `func` must start with an `async.token` 83 /// and be continued with some number of `async.value`s. 84 /// 85 /// The func given to this function needs to have been preprocessed to have 86 /// either branch or yield ops as terminators. Branches to the cleanup block are 87 /// inserted after each yield. 88 /// 89 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 90 /// 91 /// - `entry` block sets up the coroutine. 92 /// - `set_error` block sets completion token and async values state to error. 93 /// - `cleanup` block cleans up the coroutine state. 94 /// - `suspend block after the @llvm.coro.end() defines what value will be 95 /// returned to the initial caller of a coroutine. Everything before the 96 /// @llvm.coro.end() will be executed at every suspension point. 97 /// 98 /// Coroutine structure (only the important bits): 99 /// 100 /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 101 /// { 102 /// ^entry(<function-arguments>): 103 /// %token = <async token> : !async.token // create async runtime token 104 /// %value = <async value> : !async.value<T> // create async value 105 /// %id = async.coro.id // create a coroutine id 106 /// %hdl = async.coro.begin %id // create a coroutine handle 107 /// br ^preexisting_entry_block 108 /// 109 /// /* preexisting blocks modified to branch to the cleanup block */ 110 /// 111 /// ^set_error: // this block created lazily only if needed (see code below) 112 /// async.runtime.set_error %token : !async.token 113 /// async.runtime.set_error %value : !async.value<T> 114 /// br ^cleanup 115 /// 116 /// ^cleanup: 117 /// async.coro.free %hdl // delete the coroutine state 118 /// br ^suspend 119 /// 120 /// ^suspend: 121 /// async.coro.end %hdl // marks the end of a coroutine 122 /// return %token, %value : !async.token, !async.value<T> 123 /// } 124 /// 125 static CoroMachinery setupCoroMachinery(FuncOp func) { 126 assert(!func.getBlocks().empty() && "Function must have an entry block"); 127 128 MLIRContext *ctx = func.getContext(); 129 Block *entryBlock = &func.getBlocks().front(); 130 Block *originalEntryBlock = 131 entryBlock->splitBlock(entryBlock->getOperations().begin()); 132 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 133 134 // ------------------------------------------------------------------------ // 135 // Allocate async token/values that we will return from a ramp function. 136 // ------------------------------------------------------------------------ // 137 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 138 139 llvm::SmallVector<Value, 4> retValues; 140 for (auto resType : func.getCallableResults().drop_front()) 141 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 142 143 // ------------------------------------------------------------------------ // 144 // Initialize coroutine: get coroutine id and coroutine handle. 145 // ------------------------------------------------------------------------ // 146 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 147 auto coroHdlOp = 148 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 149 builder.create<BranchOp>(originalEntryBlock); 150 151 Block *cleanupBlock = func.addBlock(); 152 Block *suspendBlock = func.addBlock(); 153 154 // ------------------------------------------------------------------------ // 155 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 156 // ------------------------------------------------------------------------ // 157 builder.setInsertionPointToStart(cleanupBlock); 158 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 159 160 // Branch into the suspend block. 161 builder.create<BranchOp>(suspendBlock); 162 163 // ------------------------------------------------------------------------ // 164 // Coroutine suspend block: mark the end of a coroutine and return allocated 165 // async token. 166 // ------------------------------------------------------------------------ // 167 builder.setInsertionPointToStart(suspendBlock); 168 169 // Mark the end of a coroutine: async.coro.end 170 builder.create<CoroEndOp>(coroHdlOp.handle()); 171 172 // Return created `async.token` and `async.values` from the suspend block. 173 // This will be the return value of a coroutine ramp function. 174 SmallVector<Value, 4> ret{retToken}; 175 ret.insert(ret.end(), retValues.begin(), retValues.end()); 176 builder.create<ReturnOp>(ret); 177 178 // `async.await` op lowering will create resume blocks for async 179 // continuations, and will conditionally branch to cleanup or suspend blocks. 180 181 for (Block &block : func.body().getBlocks()) { 182 if (&block == entryBlock || &block == cleanupBlock || 183 &block == suspendBlock) 184 continue; 185 Operation *terminator = block.getTerminator(); 186 if (auto yield = dyn_cast<YieldOp>(terminator)) { 187 builder.setInsertionPointToEnd(&block); 188 builder.create<BranchOp>(cleanupBlock); 189 } 190 } 191 192 CoroMachinery machinery; 193 machinery.func = func; 194 machinery.asyncToken = retToken; 195 machinery.returnValues = retValues; 196 machinery.coroHandle = coroHdlOp.handle(); 197 machinery.entry = entryBlock; 198 machinery.setError = nullptr; // created lazily only if needed 199 machinery.cleanup = cleanupBlock; 200 machinery.suspend = suspendBlock; 201 return machinery; 202 } 203 204 // Lazily creates `set_error` block only if it is required for lowering to the 205 // runtime operations (see for example lowering of assert operation). 206 static Block *setupSetErrorBlock(CoroMachinery &coro) { 207 if (coro.setError) 208 return coro.setError; 209 210 coro.setError = coro.func.addBlock(); 211 coro.setError->moveBefore(coro.cleanup); 212 213 auto builder = 214 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 215 216 // Coroutine set_error block: set error on token and all returned values. 217 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 218 for (Value retValue : coro.returnValues) 219 builder.create<RuntimeSetErrorOp>(retValue); 220 221 // Branch into the cleanup block. 222 builder.create<BranchOp>(coro.cleanup); 223 224 return coro.setError; 225 } 226 227 /// Outline the body region attached to the `async.execute` op into a standalone 228 /// function. 229 /// 230 /// Note that this is not reversible transformation. 231 static std::pair<FuncOp, CoroMachinery> 232 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 233 ModuleOp module = execute->getParentOfType<ModuleOp>(); 234 235 MLIRContext *ctx = module.getContext(); 236 Location loc = execute.getLoc(); 237 238 // Make sure that all constants will be inside the outlined async function to 239 // reduce the number of function arguments. 240 cloneConstantsIntoTheRegion(execute.body()); 241 242 // Collect all outlined function inputs. 243 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 244 execute.dependencies().end()); 245 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 246 getUsedValuesDefinedAbove(execute.body(), functionInputs); 247 248 // Collect types for the outlined function inputs and outputs. 249 auto typesRange = llvm::map_range( 250 functionInputs, [](Value value) { return value.getType(); }); 251 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 252 auto outputTypes = execute.getResultTypes(); 253 254 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 255 auto funcAttrs = ArrayRef<NamedAttribute>(); 256 257 // TODO: Derive outlined function name from the parent FuncOp (support 258 // multiple nested async.execute operations). 259 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 260 symbolTable.insert(func); 261 262 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 263 auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 264 265 // Prepare for coroutine conversion by creating the body of the function. 266 { 267 size_t numDependencies = execute.dependencies().size(); 268 size_t numOperands = execute.operands().size(); 269 270 // Await on all dependencies before starting to execute the body region. 271 for (size_t i = 0; i < numDependencies; ++i) 272 builder.create<AwaitOp>(func.getArgument(i)); 273 274 // Await on all async value operands and unwrap the payload. 275 SmallVector<Value, 4> unwrappedOperands(numOperands); 276 for (size_t i = 0; i < numOperands; ++i) { 277 Value operand = func.getArgument(numDependencies + i); 278 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 279 } 280 281 // Map from function inputs defined above the execute op to the function 282 // arguments. 283 BlockAndValueMapping valueMapping; 284 valueMapping.map(functionInputs, func.getArguments()); 285 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 286 287 // Clone all operations from the execute operation body into the outlined 288 // function body. 289 for (Operation &op : execute.body().getOps()) 290 builder.clone(op, valueMapping); 291 } 292 293 // Adding entry/cleanup/suspend blocks. 294 CoroMachinery coro = setupCoroMachinery(func); 295 296 // Suspend async function at the end of an entry block, and resume it using 297 // Async resume operation (execution will be resumed in a thread managed by 298 // the async runtime). 299 { 300 BranchOp branch = cast<BranchOp>(coro.entry->getTerminator()); 301 builder.setInsertionPointToEnd(coro.entry); 302 303 // Save the coroutine state: async.coro.save 304 auto coroSaveOp = 305 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 306 307 // Pass coroutine to the runtime to be resumed on a runtime managed 308 // thread. 309 builder.create<RuntimeResumeOp>(coro.coroHandle); 310 311 // Add async.coro.suspend as a suspended block terminator. 312 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 313 branch.getDest(), coro.cleanup); 314 315 branch.erase(); 316 } 317 318 // Replace the original `async.execute` with a call to outlined function. 319 { 320 ImplicitLocOpBuilder callBuilder(loc, execute); 321 auto callOutlinedFunc = callBuilder.create<CallOp>( 322 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 323 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 324 execute.erase(); 325 } 326 327 return {func, coro}; 328 } 329 330 //===----------------------------------------------------------------------===// 331 // Convert async.create_group operation to async.runtime.create_group 332 //===----------------------------------------------------------------------===// 333 334 namespace { 335 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 336 public: 337 using OpConversionPattern::OpConversionPattern; 338 339 LogicalResult 340 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 341 ConversionPatternRewriter &rewriter) const override { 342 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 343 op, GroupType::get(op->getContext()), operands); 344 return success(); 345 } 346 }; 347 } // namespace 348 349 //===----------------------------------------------------------------------===// 350 // Convert async.add_to_group operation to async.runtime.add_to_group. 351 //===----------------------------------------------------------------------===// 352 353 namespace { 354 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 355 public: 356 using OpConversionPattern::OpConversionPattern; 357 358 LogicalResult 359 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 360 ConversionPatternRewriter &rewriter) const override { 361 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 362 op, rewriter.getIndexType(), operands); 363 return success(); 364 } 365 }; 366 } // namespace 367 368 //===----------------------------------------------------------------------===// 369 // Convert async.await and async.await_all operations to the async.runtime.await 370 // or async.runtime.await_and_resume operations. 371 //===----------------------------------------------------------------------===// 372 373 namespace { 374 template <typename AwaitType, typename AwaitableType> 375 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 376 using AwaitAdaptor = typename AwaitType::Adaptor; 377 378 public: 379 AwaitOpLoweringBase(MLIRContext *ctx, 380 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 381 : OpConversionPattern<AwaitType>(ctx), 382 outlinedFunctions(outlinedFunctions) {} 383 384 LogicalResult 385 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 386 ConversionPatternRewriter &rewriter) const override { 387 // We can only await on one the `AwaitableType` (for `await` it can be 388 // a `token` or a `value`, for `await_all` it must be a `group`). 389 if (!op.operand().getType().template isa<AwaitableType>()) 390 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 391 392 // Check if await operation is inside the outlined coroutine function. 393 auto func = op->template getParentOfType<FuncOp>(); 394 auto outlined = outlinedFunctions.find(func); 395 const bool isInCoroutine = outlined != outlinedFunctions.end(); 396 397 Location loc = op->getLoc(); 398 Value operand = AwaitAdaptor(operands).operand(); 399 400 Type i1 = rewriter.getI1Type(); 401 402 // Inside regular functions we use the blocking wait operation to wait for 403 // the async object (token, value or group) to become available. 404 if (!isInCoroutine) { 405 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 406 builder.create<RuntimeAwaitOp>(loc, operand); 407 408 // Assert that the awaited operands is not in the error state. 409 Value isError = builder.create<RuntimeIsErrorOp>(i1, operand); 410 Value notError = builder.create<XOrOp>( 411 isError, 412 builder.create<ConstantOp>(loc, i1, builder.getIntegerAttr(i1, 1))); 413 414 builder.create<AssertOp>(notError, 415 "Awaited async operand is in error state"); 416 } 417 418 // Inside the coroutine we convert await operation into coroutine suspension 419 // point, and resume execution asynchronously. 420 if (isInCoroutine) { 421 CoroMachinery &coro = outlined->getSecond(); 422 Block *suspended = op->getBlock(); 423 424 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 425 MLIRContext *ctx = op->getContext(); 426 427 // Save the coroutine state and resume on a runtime managed thread when 428 // the operand becomes available. 429 auto coroSaveOp = 430 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 431 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 432 433 // Split the entry block before the await operation. 434 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 435 436 // Add async.coro.suspend as a suspended block terminator. 437 builder.setInsertionPointToEnd(suspended); 438 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 439 coro.cleanup); 440 441 // Split the resume block into error checking and continuation. 442 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 443 444 // Check if the awaited value is in the error state. 445 builder.setInsertionPointToStart(resume); 446 auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); 447 builder.create<CondBranchOp>(isError, 448 /*trueDest=*/setupSetErrorBlock(coro), 449 /*trueArgs=*/ArrayRef<Value>(), 450 /*falseDest=*/continuation, 451 /*falseArgs=*/ArrayRef<Value>()); 452 453 // Make sure that replacement value will be constructed in the 454 // continuation block. 455 rewriter.setInsertionPointToStart(continuation); 456 } 457 458 // Erase or replace the await operation with the new value. 459 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 460 rewriter.replaceOp(op, replaceWith); 461 else 462 rewriter.eraseOp(op); 463 464 return success(); 465 } 466 467 virtual Value getReplacementValue(AwaitType op, Value operand, 468 ConversionPatternRewriter &rewriter) const { 469 return Value(); 470 } 471 472 private: 473 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 474 }; 475 476 /// Lowering for `async.await` with a token operand. 477 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 478 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 479 480 public: 481 using Base::Base; 482 }; 483 484 /// Lowering for `async.await` with a value operand. 485 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 486 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 487 488 public: 489 using Base::Base; 490 491 Value 492 getReplacementValue(AwaitOp op, Value operand, 493 ConversionPatternRewriter &rewriter) const override { 494 // Load from the async value storage. 495 auto valueType = operand.getType().cast<ValueType>().getValueType(); 496 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 497 } 498 }; 499 500 /// Lowering for `async.await_all` operation. 501 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 502 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 503 504 public: 505 using Base::Base; 506 }; 507 508 } // namespace 509 510 //===----------------------------------------------------------------------===// 511 // Convert async.yield operation to async.runtime operations. 512 //===----------------------------------------------------------------------===// 513 514 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 515 public: 516 YieldOpLowering( 517 MLIRContext *ctx, 518 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 519 : OpConversionPattern<async::YieldOp>(ctx), 520 outlinedFunctions(outlinedFunctions) {} 521 522 LogicalResult 523 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 524 ConversionPatternRewriter &rewriter) const override { 525 // Check if yield operation is inside the async coroutine function. 526 auto func = op->template getParentOfType<FuncOp>(); 527 auto outlined = outlinedFunctions.find(func); 528 if (outlined == outlinedFunctions.end()) 529 return rewriter.notifyMatchFailure( 530 op, "operation is not inside the async coroutine function"); 531 532 Location loc = op->getLoc(); 533 const CoroMachinery &coro = outlined->getSecond(); 534 535 // Store yielded values into the async values storage and switch async 536 // values state to available. 537 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 538 Value yieldValue = std::get<0>(tuple); 539 Value asyncValue = std::get<1>(tuple); 540 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 541 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 542 } 543 544 // Switch the coroutine completion token to available state. 545 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 546 547 return success(); 548 } 549 550 private: 551 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 552 }; 553 554 //===----------------------------------------------------------------------===// 555 // Convert std.assert operation to cond_br into `set_error` block. 556 //===----------------------------------------------------------------------===// 557 558 class AssertOpLowering : public OpConversionPattern<AssertOp> { 559 public: 560 AssertOpLowering(MLIRContext *ctx, 561 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 562 : OpConversionPattern<AssertOp>(ctx), 563 outlinedFunctions(outlinedFunctions) {} 564 565 LogicalResult 566 matchAndRewrite(AssertOp op, ArrayRef<Value> operands, 567 ConversionPatternRewriter &rewriter) const override { 568 // Check if assert operation is inside the async coroutine function. 569 auto func = op->template getParentOfType<FuncOp>(); 570 auto outlined = outlinedFunctions.find(func); 571 if (outlined == outlinedFunctions.end()) 572 return rewriter.notifyMatchFailure( 573 op, "operation is not inside the async coroutine function"); 574 575 Location loc = op->getLoc(); 576 CoroMachinery &coro = outlined->getSecond(); 577 578 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 579 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 580 rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(), 581 /*trueDest=*/cont, 582 /*trueArgs=*/ArrayRef<Value>(), 583 /*falseDest=*/setupSetErrorBlock(coro), 584 /*falseArgs=*/ArrayRef<Value>()); 585 rewriter.eraseOp(op); 586 587 return success(); 588 } 589 590 private: 591 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 592 }; 593 594 //===----------------------------------------------------------------------===// 595 596 /// Rewrite a func as a coroutine by: 597 /// 1) Wrapping the results into `async.value`. 598 /// 2) Prepending the results with `async.token`. 599 /// 3) Setting up coroutine blocks. 600 /// 4) Rewriting return ops as yield op and branch op into the suspend block. 601 static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) { 602 auto *ctx = func->getContext(); 603 auto loc = func.getLoc(); 604 SmallVector<Type> resultTypes; 605 resultTypes.reserve(func.getCallableResults().size()); 606 llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 607 [](Type type) { return ValueType::get(type); }); 608 func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes)); 609 func.insertResult(0, TokenType::get(ctx), {}); 610 for (Block &block : func.getBlocks()) { 611 Operation *terminator = block.getTerminator(); 612 if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) { 613 ImplicitLocOpBuilder builder(loc, returnOp); 614 builder.create<YieldOp>(returnOp.getOperands()); 615 returnOp.erase(); 616 } 617 } 618 return setupCoroMachinery(func); 619 } 620 621 /// Rewrites a call into a function that has been rewritten as a coroutine. 622 /// 623 /// The invocation of this function is safe only when call ops are traversed in 624 /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 625 static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) { 626 auto loc = func.getLoc(); 627 ImplicitLocOpBuilder callBuilder(loc, oldCall); 628 auto newCall = callBuilder.create<CallOp>( 629 func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 630 631 // Await on the async token and all the value results and unwrap the latter. 632 callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 633 SmallVector<Value> unwrappedResults; 634 unwrappedResults.reserve(newCall->getResults().size() - 1); 635 for (Value result : newCall.getResults().drop_front()) 636 unwrappedResults.push_back( 637 callBuilder.create<AwaitOp>(loc, result).result()); 638 // Careful, when result of a call is piped into another call this could lead 639 // to a dangling pointer. 640 oldCall.replaceAllUsesWith(unwrappedResults); 641 oldCall.erase(); 642 } 643 644 static bool isAllowedToBlock(FuncOp func) { 645 return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 646 } 647 648 static LogicalResult 649 funcsToCoroutines(ModuleOp module, 650 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) { 651 // The following code supports the general case when 2 functions mutually 652 // recurse into each other. Because of this and that we are relying on 653 // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 654 // a FuncOp while inserting an equivalent coroutine, because that could lead 655 // to dangling pointers. 656 657 SmallVector<FuncOp> funcWorklist; 658 659 // Careful, it's okay to add a func to the worklist multiple times if and only 660 // if the loop processing the worklist will skip the functions that have 661 // already been converted to coroutines. 662 auto addToWorklist = [&](FuncOp func) { 663 if (isAllowedToBlock(func)) 664 return; 665 // N.B. To refactor this code into a separate pass the lookup in 666 // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 667 // func and recognizing if it has a coroutine structure is messy. Passing 668 // this dict between the passes is ugly. 669 if (isAllowedToBlock(func) || 670 outlinedFunctions.find(func) == outlinedFunctions.end()) { 671 for (Operation &op : func.body().getOps()) { 672 if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 673 funcWorklist.push_back(func); 674 break; 675 } 676 } 677 } 678 }; 679 680 // Traverse in post-order collecting for each func op the await ops it has. 681 for (FuncOp func : module.getOps<FuncOp>()) 682 addToWorklist(func); 683 684 SymbolTableCollection symbolTable; 685 SymbolUserMap symbolUserMap(symbolTable, module); 686 687 // Rewrite funcs, while updating call sites and adding them to the worklist. 688 while (!funcWorklist.empty()) { 689 auto func = funcWorklist.pop_back_val(); 690 auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 691 if (!insertion.second) 692 // This function has already been processed because this is either 693 // the corecursive case, or a caller with multiple calls to a newly 694 // created corouting. Either way, skip updating the call sites. 695 continue; 696 insertion.first->second = rewriteFuncAsCoroutine(func); 697 SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 698 symbolUserMap.getUsers(func).end()); 699 // If there are multiple calls from the same block they need to be traversed 700 // in reverse order so that symbolUserMap references are not invalidated 701 // when updating the users of the call op which is earlier in the block. 702 llvm::sort(users, [](Operation *a, Operation *b) { 703 Block *blockA = a->getBlock(); 704 Block *blockB = b->getBlock(); 705 // Impose arbitrary order on blocks so that there is a well-defined order. 706 return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 707 }); 708 // Rewrite the callsites to await on results of the newly created coroutine. 709 for (Operation *op : users) { 710 if (CallOp call = dyn_cast<mlir::CallOp>(*op)) { 711 FuncOp caller = call->getParentOfType<FuncOp>(); 712 rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 713 addToWorklist(caller); 714 } else { 715 op->emitError("Unexpected reference to func referenced by symbol"); 716 return failure(); 717 } 718 } 719 } 720 return success(); 721 } 722 723 //===----------------------------------------------------------------------===// 724 void AsyncToAsyncRuntimePass::runOnOperation() { 725 ModuleOp module = getOperation(); 726 SymbolTable symbolTable(module); 727 728 // Outline all `async.execute` body regions into async functions (coroutines). 729 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 730 731 module.walk([&](ExecuteOp execute) { 732 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 733 }); 734 735 LLVM_DEBUG({ 736 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 737 << " functions built from async.execute operations\n"; 738 }); 739 740 // Returns true if operation is inside the coroutine. 741 auto isInCoroutine = [&](Operation *op) -> bool { 742 auto parentFunc = op->getParentOfType<FuncOp>(); 743 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 744 }; 745 746 if (eliminateBlockingAwaitOps && 747 failed(funcsToCoroutines(module, outlinedFunctions))) { 748 signalPassFailure(); 749 return; 750 } 751 752 // Lower async operations to async.runtime operations. 753 MLIRContext *ctx = module->getContext(); 754 RewritePatternSet asyncPatterns(ctx); 755 756 // Conversion to async runtime augments original CFG with the coroutine CFG, 757 // and we have to make sure that structured control flow operations with async 758 // operations in nested regions will be converted to branch-based control flow 759 // before we add the coroutine basic blocks. 760 populateLoopToStdConversionPatterns(asyncPatterns); 761 762 // Async lowering does not use type converter because it must preserve all 763 // types for async.runtime operations. 764 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 765 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 766 AwaitAllOpLowering, YieldOpLowering>(ctx, 767 outlinedFunctions); 768 769 // Lower assertions to conditional branches into error blocks. 770 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 771 772 // All high level async operations must be lowered to the runtime operations. 773 ConversionTarget runtimeTarget(*ctx); 774 runtimeTarget.addLegalDialect<AsyncDialect>(); 775 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 776 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 777 778 // Decide if structured control flow has to be lowered to branch-based CFG. 779 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 780 auto walkResult = op->walk([&](Operation *nested) { 781 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 782 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 783 : WalkResult::advance(); 784 }); 785 return !walkResult.wasInterrupted(); 786 }); 787 runtimeTarget 788 .addLegalOp<AssertOp, XOrOp, ConstantOp, BranchOp, CondBranchOp>(); 789 790 // Assertions must be converted to runtime errors inside async functions. 791 runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 792 auto func = op->getParentOfType<FuncOp>(); 793 return outlinedFunctions.find(func) == outlinedFunctions.end(); 794 }); 795 796 if (eliminateBlockingAwaitOps) 797 runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 798 [&](RuntimeAwaitOp op) -> bool { 799 return isAllowedToBlock(op->getParentOfType<FuncOp>()); 800 }); 801 802 if (failed(applyPartialConversion(module, runtimeTarget, 803 std::move(asyncPatterns)))) { 804 signalPassFailure(); 805 return; 806 } 807 } 808 809 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 810 return std::make_unique<AsyncToAsyncRuntimePass>(); 811 } 812