//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering from high level async operations to async.coro
// and async.runtime operations.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-to-async-runtime"
// Prefix for functions outlined from `async.execute` op regions.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";

namespace {

/// Pass that lowers high level async operations to async.coro and
/// async.runtime operations. All of the work happens in `runOnOperation`,
/// which is defined at the end of this file.
class AsyncToAsyncRuntimePass
    : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
public:
  AsyncToAsyncRuntimePass() = default;

  /// Entry point invoked by the pass manager on the operation being processed.
  void runOnOperation() override;
};

} // namespace

//===----------------------------------------------------------------------===//
// async.execute op outlining to the coroutine functions.
48 //===----------------------------------------------------------------------===// 49 50 /// Function targeted for coroutine transformation has two additional blocks at 51 /// the end: coroutine cleanup and coroutine suspension. 52 /// 53 /// async.await op lowering additionaly creates a resume block for each 54 /// operation to enable non-blocking waiting via coroutine suspension. 55 namespace { 56 struct CoroMachinery { 57 FuncOp func; 58 59 // Async execute region returns a completion token, and an async value for 60 // each yielded value. 61 // 62 // %token, %result = async.execute -> !async.value<T> { 63 // %0 = constant ... : T 64 // async.yield %0 : T 65 // } 66 Value asyncToken; // token representing completion of the async region 67 llvm::SmallVector<Value, 4> returnValues; // returned async values 68 69 Value coroHandle; // coroutine handle (!async.coro.handle value) 70 Block *setError; // switch completion token and all values to error state 71 Block *cleanup; // coroutine cleanup block 72 Block *suspend; // coroutine suspension block 73 }; 74 } // namespace 75 76 /// Builds an coroutine template compatible with LLVM coroutines switched-resume 77 /// lowering using `async.runtime.*` and `async.coro.*` operations. 78 /// 79 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 80 /// 81 /// - `entry` block sets up the coroutine. 82 /// - `set_error` block sets completion token and async values state to error. 83 /// - `cleanup` block cleans up the coroutine state. 84 /// - `suspend block after the @llvm.coro.end() defines what value will be 85 /// returned to the initial caller of a coroutine. Everything before the 86 /// @llvm.coro.end() will be executed at every suspension point. 
87 /// 88 /// Coroutine structure (only the important bits): 89 /// 90 /// func @async_execute_fn(<function-arguments>) 91 /// -> (!async.token, !async.value<T>) 92 /// { 93 /// ^entry(<function-arguments>): 94 /// %token = <async token> : !async.token // create async runtime token 95 /// %value = <async value> : !async.value<T> // create async value 96 /// %id = async.coro.id // create a coroutine id 97 /// %hdl = async.coro.begin %id // create a coroutine handle 98 /// br ^cleanup 99 /// 100 /// ^set_error: // this block created lazily only if needed (see code below) 101 /// async.runtime.set_error %token : !async.token 102 /// async.runtime.set_error %value : !async.value<T> 103 /// br ^cleanup 104 /// 105 /// ^cleanup: 106 /// async.coro.free %hdl // delete the coroutine state 107 /// br ^suspend 108 /// 109 /// ^suspend: 110 /// async.coro.end %hdl // marks the end of a coroutine 111 /// return %token, %value : !async.token, !async.value<T> 112 /// } 113 /// 114 /// The actual code for the async.execute operation body region will be inserted 115 /// before the entry block terminator. 116 /// 117 /// 118 static CoroMachinery setupCoroMachinery(FuncOp func) { 119 assert(func.getBody().empty() && "Function must have empty body"); 120 121 MLIRContext *ctx = func.getContext(); 122 Block *entryBlock = func.addEntryBlock(); 123 124 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 125 126 // ------------------------------------------------------------------------ // 127 // Allocate async token/values that we will return from a ramp function. 
128 // ------------------------------------------------------------------------ // 129 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 130 131 llvm::SmallVector<Value, 4> retValues; 132 for (auto resType : func.getCallableResults().drop_front()) 133 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 134 135 // ------------------------------------------------------------------------ // 136 // Initialize coroutine: get coroutine id and coroutine handle. 137 // ------------------------------------------------------------------------ // 138 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 139 auto coroHdlOp = 140 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 141 142 Block *cleanupBlock = func.addBlock(); 143 Block *suspendBlock = func.addBlock(); 144 145 // ------------------------------------------------------------------------ // 146 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 147 // ------------------------------------------------------------------------ // 148 builder.setInsertionPointToStart(cleanupBlock); 149 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 150 151 // Branch into the suspend block. 152 builder.create<BranchOp>(suspendBlock); 153 154 // ------------------------------------------------------------------------ // 155 // Coroutine suspend block: mark the end of a coroutine and return allocated 156 // async token. 157 // ------------------------------------------------------------------------ // 158 builder.setInsertionPointToStart(suspendBlock); 159 160 // Mark the end of a coroutine: async.coro.end 161 builder.create<CoroEndOp>(coroHdlOp.handle()); 162 163 // Return created `async.token` and `async.values` from the suspend block. 164 // This will be the return value of a coroutine ramp function. 
165 SmallVector<Value, 4> ret{retToken}; 166 ret.insert(ret.end(), retValues.begin(), retValues.end()); 167 builder.create<ReturnOp>(ret); 168 169 // Branch from the entry block to the cleanup block to create a valid CFG. 170 builder.setInsertionPointToEnd(entryBlock); 171 builder.create<BranchOp>(cleanupBlock); 172 173 // `async.await` op lowering will create resume blocks for async 174 // continuations, and will conditionally branch to cleanup or suspend blocks. 175 176 CoroMachinery machinery; 177 machinery.func = func; 178 machinery.asyncToken = retToken; 179 machinery.returnValues = retValues; 180 machinery.coroHandle = coroHdlOp.handle(); 181 machinery.setError = nullptr; // created lazily only if needed 182 machinery.cleanup = cleanupBlock; 183 machinery.suspend = suspendBlock; 184 return machinery; 185 } 186 187 // Lazily creates `set_error` block only if it is required for lowering to the 188 // runtime operations (see for example lowering of assert operation). 189 static Block *setupSetErrorBlock(CoroMachinery &coro) { 190 if (coro.setError) 191 return coro.setError; 192 193 coro.setError = coro.func.addBlock(); 194 coro.setError->moveBefore(coro.cleanup); 195 196 auto builder = 197 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 198 199 // Coroutine set_error block: set error on token and all returned values. 200 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 201 for (Value retValue : coro.returnValues) 202 builder.create<RuntimeSetErrorOp>(retValue); 203 204 // Branch into the cleanup block. 205 builder.create<BranchOp>(coro.cleanup); 206 207 return coro.setError; 208 } 209 210 /// Outline the body region attached to the `async.execute` op into a standalone 211 /// function. 212 /// 213 /// Note that this is not reversible transformation. 
214 static std::pair<FuncOp, CoroMachinery> 215 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 216 ModuleOp module = execute->getParentOfType<ModuleOp>(); 217 218 MLIRContext *ctx = module.getContext(); 219 Location loc = execute.getLoc(); 220 221 // Collect all outlined function inputs. 222 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 223 execute.dependencies().end()); 224 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 225 getUsedValuesDefinedAbove(execute.body(), functionInputs); 226 227 // Collect types for the outlined function inputs and outputs. 228 auto typesRange = llvm::map_range( 229 functionInputs, [](Value value) { return value.getType(); }); 230 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 231 auto outputTypes = execute.getResultTypes(); 232 233 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 234 auto funcAttrs = ArrayRef<NamedAttribute>(); 235 236 // TODO: Derive outlined function name from the parent FuncOp (support 237 // multiple nested async.execute operations). 238 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 239 symbolTable.insert(func); 240 241 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 242 243 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 244 // blocks, adding async.coro operations and setting up control flow. 245 CoroMachinery coro = setupCoroMachinery(func); 246 247 // Suspend async function at the end of an entry block, and resume it using 248 // Async resume operation (execution will be resumed in a thread managed by 249 // the async runtime). 
250 Block *entryBlock = &func.getBlocks().front(); 251 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 252 253 // Save the coroutine state: async.coro.save 254 auto coroSaveOp = 255 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 256 257 // Pass coroutine to the runtime to be resumed on a runtime managed thread. 258 builder.create<RuntimeResumeOp>(coro.coroHandle); 259 260 // Split the entry block before the terminator (branch to suspend block). 261 auto *terminatorOp = entryBlock->getTerminator(); 262 Block *suspended = terminatorOp->getBlock(); 263 Block *resume = suspended->splitBlock(terminatorOp); 264 265 // Add async.coro.suspend as a suspended block terminator. 266 builder.setInsertionPointToEnd(suspended); 267 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 268 coro.cleanup); 269 270 size_t numDependencies = execute.dependencies().size(); 271 size_t numOperands = execute.operands().size(); 272 273 // Await on all dependencies before starting to execute the body region. 274 builder.setInsertionPointToStart(resume); 275 for (size_t i = 0; i < numDependencies; ++i) 276 builder.create<AwaitOp>(func.getArgument(i)); 277 278 // Await on all async value operands and unwrap the payload. 279 SmallVector<Value, 4> unwrappedOperands(numOperands); 280 for (size_t i = 0; i < numOperands; ++i) { 281 Value operand = func.getArgument(numDependencies + i); 282 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 283 } 284 285 // Map from function inputs defined above the execute op to the function 286 // arguments. 287 BlockAndValueMapping valueMapping; 288 valueMapping.map(functionInputs, func.getArguments()); 289 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 290 291 // Clone all operations from the execute operation body into the outlined 292 // function body. 
293 for (Operation &op : execute.body().getOps()) 294 builder.clone(op, valueMapping); 295 296 // Replace the original `async.execute` with a call to outlined function. 297 ImplicitLocOpBuilder callBuilder(loc, execute); 298 auto callOutlinedFunc = callBuilder.create<CallOp>( 299 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 300 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 301 execute.erase(); 302 303 return {func, coro}; 304 } 305 306 //===----------------------------------------------------------------------===// 307 // Convert async.create_group operation to async.runtime.create_group 308 //===----------------------------------------------------------------------===// 309 310 namespace { 311 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 312 public: 313 using OpConversionPattern::OpConversionPattern; 314 315 LogicalResult 316 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 317 ConversionPatternRewriter &rewriter) const override { 318 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 319 op, GroupType::get(op->getContext()), operands); 320 return success(); 321 } 322 }; 323 } // namespace 324 325 //===----------------------------------------------------------------------===// 326 // Convert async.add_to_group operation to async.runtime.add_to_group. 
327 //===----------------------------------------------------------------------===// 328 329 namespace { 330 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 331 public: 332 using OpConversionPattern::OpConversionPattern; 333 334 LogicalResult 335 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 336 ConversionPatternRewriter &rewriter) const override { 337 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 338 op, rewriter.getIndexType(), operands); 339 return success(); 340 } 341 }; 342 } // namespace 343 344 //===----------------------------------------------------------------------===// 345 // Convert async.await and async.await_all operations to the async.runtime.await 346 // or async.runtime.await_and_resume operations. 347 //===----------------------------------------------------------------------===// 348 349 namespace { 350 template <typename AwaitType, typename AwaitableType> 351 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 352 using AwaitAdaptor = typename AwaitType::Adaptor; 353 354 public: 355 AwaitOpLoweringBase(MLIRContext *ctx, 356 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 357 : OpConversionPattern<AwaitType>(ctx), 358 outlinedFunctions(outlinedFunctions) {} 359 360 LogicalResult 361 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 362 ConversionPatternRewriter &rewriter) const override { 363 // We can only await on one the `AwaitableType` (for `await` it can be 364 // a `token` or a `value`, for `await_all` it must be a `group`). 365 if (!op.operand().getType().template isa<AwaitableType>()) 366 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 367 368 // Check if await operation is inside the outlined coroutine function. 
369 auto func = op->template getParentOfType<FuncOp>(); 370 auto outlined = outlinedFunctions.find(func); 371 const bool isInCoroutine = outlined != outlinedFunctions.end(); 372 373 Location loc = op->getLoc(); 374 Value operand = AwaitAdaptor(operands).operand(); 375 376 // Inside regular functions we use the blocking wait operation to wait for 377 // the async object (token, value or group) to become available. 378 if (!isInCoroutine) 379 rewriter.create<RuntimeAwaitOp>(loc, operand); 380 381 // Inside the coroutine we convert await operation into coroutine suspension 382 // point, and resume execution asynchronously. 383 if (isInCoroutine) { 384 CoroMachinery &coro = outlined->getSecond(); 385 Block *suspended = op->getBlock(); 386 387 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 388 MLIRContext *ctx = op->getContext(); 389 390 // Save the coroutine state and resume on a runtime managed thread when 391 // the operand becomes available. 392 auto coroSaveOp = 393 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 394 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 395 396 // Split the entry block before the await operation. 397 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 398 399 // Add async.coro.suspend as a suspended block terminator. 400 builder.setInsertionPointToEnd(suspended); 401 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 402 coro.cleanup); 403 404 // Split the resume block into error checking and continuation. 405 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 406 407 // Check if the awaited value is in the error state. 
408 builder.setInsertionPointToStart(resume); 409 auto isError = 410 builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand); 411 builder.create<CondBranchOp>(isError, 412 /*trueDest=*/setupSetErrorBlock(coro), 413 /*trueArgs=*/ArrayRef<Value>(), 414 /*falseDest=*/continuation, 415 /*falseArgs=*/ArrayRef<Value>()); 416 417 // Make sure that replacement value will be constructed in the 418 // continuation block. 419 rewriter.setInsertionPointToStart(continuation); 420 } 421 422 // Erase or replace the await operation with the new value. 423 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 424 rewriter.replaceOp(op, replaceWith); 425 else 426 rewriter.eraseOp(op); 427 428 return success(); 429 } 430 431 virtual Value getReplacementValue(AwaitType op, Value operand, 432 ConversionPatternRewriter &rewriter) const { 433 return Value(); 434 } 435 436 private: 437 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 438 }; 439 440 /// Lowering for `async.await` with a token operand. 441 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 442 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 443 444 public: 445 using Base::Base; 446 }; 447 448 /// Lowering for `async.await` with a value operand. 449 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 450 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 451 452 public: 453 using Base::Base; 454 455 Value 456 getReplacementValue(AwaitOp op, Value operand, 457 ConversionPatternRewriter &rewriter) const override { 458 // Load from the async value storage. 459 auto valueType = operand.getType().cast<ValueType>().getValueType(); 460 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 461 } 462 }; 463 464 /// Lowering for `async.await_all` operation. 
465 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 466 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 467 468 public: 469 using Base::Base; 470 }; 471 472 } // namespace 473 474 //===----------------------------------------------------------------------===// 475 // Convert async.yield operation to async.runtime operations. 476 //===----------------------------------------------------------------------===// 477 478 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 479 public: 480 YieldOpLowering( 481 MLIRContext *ctx, 482 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 483 : OpConversionPattern<async::YieldOp>(ctx), 484 outlinedFunctions(outlinedFunctions) {} 485 486 LogicalResult 487 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 488 ConversionPatternRewriter &rewriter) const override { 489 // Check if yield operation is inside the async coroutine function. 490 auto func = op->template getParentOfType<FuncOp>(); 491 auto outlined = outlinedFunctions.find(func); 492 if (outlined == outlinedFunctions.end()) 493 return rewriter.notifyMatchFailure( 494 op, "operation is not inside the async coroutine function"); 495 496 Location loc = op->getLoc(); 497 const CoroMachinery &coro = outlined->getSecond(); 498 499 // Store yielded values into the async values storage and switch async 500 // values state to available. 501 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 502 Value yieldValue = std::get<0>(tuple); 503 Value asyncValue = std::get<1>(tuple); 504 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 505 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 506 } 507 508 // Switch the coroutine completion token to available state. 
509 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 510 511 return success(); 512 } 513 514 private: 515 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 516 }; 517 518 //===----------------------------------------------------------------------===// 519 // Convert std.assert operation to cond_br into `set_error` block. 520 //===----------------------------------------------------------------------===// 521 522 class AssertOpLowering : public OpConversionPattern<AssertOp> { 523 public: 524 AssertOpLowering(MLIRContext *ctx, 525 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 526 : OpConversionPattern<AssertOp>(ctx), 527 outlinedFunctions(outlinedFunctions) {} 528 529 LogicalResult 530 matchAndRewrite(AssertOp op, ArrayRef<Value> operands, 531 ConversionPatternRewriter &rewriter) const override { 532 // Check if assert operation is inside the async coroutine function. 533 auto func = op->template getParentOfType<FuncOp>(); 534 auto outlined = outlinedFunctions.find(func); 535 if (outlined == outlinedFunctions.end()) 536 return rewriter.notifyMatchFailure( 537 op, "operation is not inside the async coroutine function"); 538 539 Location loc = op->getLoc(); 540 CoroMachinery &coro = outlined->getSecond(); 541 542 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 543 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 544 rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(), 545 /*trueDest=*/cont, 546 /*trueArgs=*/ArrayRef<Value>(), 547 /*falseDest=*/setupSetErrorBlock(coro), 548 /*falseArgs=*/ArrayRef<Value>()); 549 rewriter.eraseOp(op); 550 551 return success(); 552 } 553 554 private: 555 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 556 }; 557 558 //===----------------------------------------------------------------------===// 559 560 void AsyncToAsyncRuntimePass::runOnOperation() { 561 ModuleOp module = getOperation(); 562 SymbolTable symbolTable(module); 563 564 // 
Outline all `async.execute` body regions into async functions (coroutines). 565 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 566 567 module.walk([&](ExecuteOp execute) { 568 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 569 }); 570 571 LLVM_DEBUG({ 572 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 573 << " functions built from async.execute operations\n"; 574 }); 575 576 // Returns true if operation is inside the coroutine. 577 auto isInCoroutine = [&](Operation *op) -> bool { 578 auto parentFunc = op->getParentOfType<FuncOp>(); 579 return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 580 }; 581 582 // Lower async operations to async.runtime operations. 583 MLIRContext *ctx = module->getContext(); 584 RewritePatternSet asyncPatterns(ctx); 585 586 // Conversion to async runtime augments original CFG with the coroutine CFG, 587 // and we have to make sure that structured control flow operations with async 588 // operations in nested regions will be converted to branch-based control flow 589 // before we add the coroutine basic blocks. 590 populateLoopToStdConversionPatterns(asyncPatterns); 591 592 // Async lowering does not use type converter because it must preserve all 593 // types for async.runtime operations. 594 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 595 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 596 AwaitAllOpLowering, YieldOpLowering>(ctx, 597 outlinedFunctions); 598 599 // Lower assertions to conditional branches into error blocks. 600 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 601 602 // All high level async operations must be lowered to the runtime operations. 
603 ConversionTarget runtimeTarget(*ctx); 604 runtimeTarget.addLegalDialect<AsyncDialect>(); 605 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 606 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 607 608 // Decide if structured control flow has to be lowered to branch-based CFG. 609 runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 610 auto walkResult = op->walk([&](Operation *nested) { 611 bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 612 return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 613 : WalkResult::advance(); 614 }); 615 return !walkResult.wasInterrupted(); 616 }); 617 runtimeTarget.addLegalOp<BranchOp, CondBranchOp>(); 618 619 // Assertions must be converted to runtime errors inside async functions. 620 runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 621 auto func = op->getParentOfType<FuncOp>(); 622 return outlinedFunctions.find(func) == outlinedFunctions.end(); 623 }); 624 625 if (failed(applyPartialConversion(module, runtimeTarget, 626 std::move(asyncPatterns)))) { 627 signalPassFailure(); 628 return; 629 } 630 } 631 632 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 633 return std::make_unique<AsyncToAsyncRuntimePass>(); 634 } 635