1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering from high level async operations to async.coro 10 // and async.runtime operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Async/IR/Async.h" 16 #include "mlir/Dialect/Async/Passes.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/IR/BlockAndValueMapping.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/PatternMatch.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include "mlir/Transforms/RegionUtils.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/Support/Debug.h" 25 26 using namespace mlir; 27 using namespace mlir::async; 28 29 #define DEBUG_TYPE "async-to-async-runtime" 30 // Prefix for functions outlined from `async.execute` op regions. 31 static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 32 33 namespace { 34 35 class AsyncToAsyncRuntimePass 36 : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 37 public: 38 AsyncToAsyncRuntimePass() = default; 39 void runOnOperation() override; 40 }; 41 42 } // namespace 43 44 //===----------------------------------------------------------------------===// 45 // async.execute op outlining to the coroutine functions. 46 //===----------------------------------------------------------------------===// 47 48 /// Function targeted for coroutine transformation has two additional blocks at 49 /// the end: coroutine cleanup and coroutine suspension. 50 /// 51 /// async.await op lowering additionaly creates a resume block for each 52 /// operation to enable non-blocking waiting via coroutine suspension. 53 namespace { 54 struct CoroMachinery { 55 FuncOp func; 56 57 // Async execute region returns a completion token, and an async value for 58 // each yielded value. 59 // 60 // %token, %result = async.execute -> !async.value<T> { 61 // %0 = constant ... : T 62 // async.yield %0 : T 63 // } 64 Value asyncToken; // token representing completion of the async region 65 llvm::SmallVector<Value, 4> returnValues; // returned async values 66 67 Value coroHandle; // coroutine handle (!async.coro.handle value) 68 Block *setError; // switch completion token and all values to error state 69 Block *cleanup; // coroutine cleanup block 70 Block *suspend; // coroutine suspension block 71 }; 72 } // namespace 73 74 /// Builds an coroutine template compatible with LLVM coroutines switched-resume 75 /// lowering using `async.runtime.*` and `async.coro.*` operations. 76 /// 77 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 78 /// 79 /// - `entry` block sets up the coroutine. 80 /// - `set_error` block sets completion token and async values state to error. 81 /// - `cleanup` block cleans up the coroutine state. 82 /// - `suspend block after the @llvm.coro.end() defines what value will be 83 /// returned to the initial caller of a coroutine. Everything before the 84 /// @llvm.coro.end() will be executed at every suspension point. 85 /// 86 /// Coroutine structure (only the important bits): 87 /// 88 /// func @async_execute_fn(<function-arguments>) 89 /// -> (!async.token, !async.value<T>) 90 /// { 91 /// ^entry(<function-arguments>): 92 /// %token = <async token> : !async.token // create async runtime token 93 /// %value = <async value> : !async.value<T> // create async value 94 /// %id = async.coro.id // create a coroutine id 95 /// %hdl = async.coro.begin %id // create a coroutine handle 96 /// br ^cleanup 97 /// 98 /// ^set_error: // this block created lazily only if needed (see code below) 99 /// async.runtime.set_error %token : !async.token 100 /// async.runtime.set_error %value : !async.value<T> 101 /// br ^cleanup 102 /// 103 /// ^cleanup: 104 /// async.coro.free %hdl // delete the coroutine state 105 /// br ^suspend 106 /// 107 /// ^suspend: 108 /// async.coro.end %hdl // marks the end of a coroutine 109 /// return %token, %value : !async.token, !async.value<T> 110 /// } 111 /// 112 /// The actual code for the async.execute operation body region will be inserted 113 /// before the entry block terminator. 114 /// 115 /// 116 static CoroMachinery setupCoroMachinery(FuncOp func) { 117 assert(func.getBody().empty() && "Function must have empty body"); 118 119 MLIRContext *ctx = func.getContext(); 120 Block *entryBlock = func.addEntryBlock(); 121 122 auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 123 124 // ------------------------------------------------------------------------ // 125 // Allocate async token/values that we will return from a ramp function. 126 // ------------------------------------------------------------------------ // 127 auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 128 129 llvm::SmallVector<Value, 4> retValues; 130 for (auto resType : func.getCallableResults().drop_front()) 131 retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 132 133 // ------------------------------------------------------------------------ // 134 // Initialize coroutine: get coroutine id and coroutine handle. 135 // ------------------------------------------------------------------------ // 136 auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 137 auto coroHdlOp = 138 builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 139 140 Block *cleanupBlock = func.addBlock(); 141 Block *suspendBlock = func.addBlock(); 142 143 // ------------------------------------------------------------------------ // 144 // Coroutine cleanup block: deallocate coroutine frame, free the memory. 145 // ------------------------------------------------------------------------ // 146 builder.setInsertionPointToStart(cleanupBlock); 147 builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 148 149 // Branch into the suspend block. 150 builder.create<BranchOp>(suspendBlock); 151 152 // ------------------------------------------------------------------------ // 153 // Coroutine suspend block: mark the end of a coroutine and return allocated 154 // async token. 155 // ------------------------------------------------------------------------ // 156 builder.setInsertionPointToStart(suspendBlock); 157 158 // Mark the end of a coroutine: async.coro.end 159 builder.create<CoroEndOp>(coroHdlOp.handle()); 160 161 // Return created `async.token` and `async.values` from the suspend block. 162 // This will be the return value of a coroutine ramp function. 163 SmallVector<Value, 4> ret{retToken}; 164 ret.insert(ret.end(), retValues.begin(), retValues.end()); 165 builder.create<ReturnOp>(ret); 166 167 // Branch from the entry block to the cleanup block to create a valid CFG. 168 builder.setInsertionPointToEnd(entryBlock); 169 builder.create<BranchOp>(cleanupBlock); 170 171 // `async.await` op lowering will create resume blocks for async 172 // continuations, and will conditionally branch to cleanup or suspend blocks. 173 174 CoroMachinery machinery; 175 machinery.func = func; 176 machinery.asyncToken = retToken; 177 machinery.returnValues = retValues; 178 machinery.coroHandle = coroHdlOp.handle(); 179 machinery.setError = nullptr; // created lazily only if needed 180 machinery.cleanup = cleanupBlock; 181 machinery.suspend = suspendBlock; 182 return machinery; 183 } 184 185 // Lazily creates `set_error` block only if it is required for lowering to the 186 // runtime operations (see for example lowering of assert operation). 187 static Block *setupSetErrorBlock(CoroMachinery &coro) { 188 if (coro.setError) 189 return coro.setError; 190 191 coro.setError = coro.func.addBlock(); 192 coro.setError->moveBefore(coro.cleanup); 193 194 auto builder = 195 ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 196 197 // Coroutine set_error block: set error on token and all returned values. 198 builder.create<RuntimeSetErrorOp>(coro.asyncToken); 199 for (Value retValue : coro.returnValues) 200 builder.create<RuntimeSetErrorOp>(retValue); 201 202 // Branch into the cleanup block. 203 builder.create<BranchOp>(coro.cleanup); 204 205 return coro.setError; 206 } 207 208 /// Outline the body region attached to the `async.execute` op into a standalone 209 /// function. 210 /// 211 /// Note that this is not reversible transformation. 212 static std::pair<FuncOp, CoroMachinery> 213 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 214 ModuleOp module = execute->getParentOfType<ModuleOp>(); 215 216 MLIRContext *ctx = module.getContext(); 217 Location loc = execute.getLoc(); 218 219 // Collect all outlined function inputs. 220 SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 221 execute.dependencies().end()); 222 functionInputs.insert(execute.operands().begin(), execute.operands().end()); 223 getUsedValuesDefinedAbove(execute.body(), functionInputs); 224 225 // Collect types for the outlined function inputs and outputs. 226 auto typesRange = llvm::map_range( 227 functionInputs, [](Value value) { return value.getType(); }); 228 SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 229 auto outputTypes = execute.getResultTypes(); 230 231 auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 232 auto funcAttrs = ArrayRef<NamedAttribute>(); 233 234 // TODO: Derive outlined function name from the parent FuncOp (support 235 // multiple nested async.execute operations). 236 FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 237 symbolTable.insert(func); 238 239 SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 240 241 // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 242 // blocks, adding async.coro operations and setting up control flow. 243 CoroMachinery coro = setupCoroMachinery(func); 244 245 // Suspend async function at the end of an entry block, and resume it using 246 // Async resume operation (execution will be resumed in a thread managed by 247 // the async runtime). 248 Block *entryBlock = &func.getBlocks().front(); 249 auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 250 251 // Save the coroutine state: async.coro.save 252 auto coroSaveOp = 253 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 254 255 // Pass coroutine to the runtime to be resumed on a runtime managed thread. 256 builder.create<RuntimeResumeOp>(coro.coroHandle); 257 258 // Split the entry block before the terminator (branch to suspend block). 259 auto *terminatorOp = entryBlock->getTerminator(); 260 Block *suspended = terminatorOp->getBlock(); 261 Block *resume = suspended->splitBlock(terminatorOp); 262 263 // Add async.coro.suspend as a suspended block terminator. 264 builder.setInsertionPointToEnd(suspended); 265 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 266 coro.cleanup); 267 268 size_t numDependencies = execute.dependencies().size(); 269 size_t numOperands = execute.operands().size(); 270 271 // Await on all dependencies before starting to execute the body region. 272 builder.setInsertionPointToStart(resume); 273 for (size_t i = 0; i < numDependencies; ++i) 274 builder.create<AwaitOp>(func.getArgument(i)); 275 276 // Await on all async value operands and unwrap the payload. 277 SmallVector<Value, 4> unwrappedOperands(numOperands); 278 for (size_t i = 0; i < numOperands; ++i) { 279 Value operand = func.getArgument(numDependencies + i); 280 unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 281 } 282 283 // Map from function inputs defined above the execute op to the function 284 // arguments. 285 BlockAndValueMapping valueMapping; 286 valueMapping.map(functionInputs, func.getArguments()); 287 valueMapping.map(execute.body().getArguments(), unwrappedOperands); 288 289 // Clone all operations from the execute operation body into the outlined 290 // function body. 291 for (Operation &op : execute.body().getOps()) 292 builder.clone(op, valueMapping); 293 294 // Replace the original `async.execute` with a call to outlined function. 295 ImplicitLocOpBuilder callBuilder(loc, execute); 296 auto callOutlinedFunc = callBuilder.create<CallOp>( 297 func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 298 execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 299 execute.erase(); 300 301 return {func, coro}; 302 } 303 304 //===----------------------------------------------------------------------===// 305 // Convert async.create_group operation to async.runtime.create_group 306 //===----------------------------------------------------------------------===// 307 308 namespace { 309 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 310 public: 311 using OpConversionPattern::OpConversionPattern; 312 313 LogicalResult 314 matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 315 ConversionPatternRewriter &rewriter) const override { 316 rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 317 op, GroupType::get(op->getContext()), operands); 318 return success(); 319 } 320 }; 321 } // namespace 322 323 //===----------------------------------------------------------------------===// 324 // Convert async.add_to_group operation to async.runtime.add_to_group. 325 //===----------------------------------------------------------------------===// 326 327 namespace { 328 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 329 public: 330 using OpConversionPattern::OpConversionPattern; 331 332 LogicalResult 333 matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 334 ConversionPatternRewriter &rewriter) const override { 335 rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 336 op, rewriter.getIndexType(), operands); 337 return success(); 338 } 339 }; 340 } // namespace 341 342 //===----------------------------------------------------------------------===// 343 // Convert async.await and async.await_all operations to the async.runtime.await 344 // or async.runtime.await_and_resume operations. 345 //===----------------------------------------------------------------------===// 346 347 namespace { 348 template <typename AwaitType, typename AwaitableType> 349 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 350 using AwaitAdaptor = typename AwaitType::Adaptor; 351 352 public: 353 AwaitOpLoweringBase(MLIRContext *ctx, 354 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 355 : OpConversionPattern<AwaitType>(ctx), 356 outlinedFunctions(outlinedFunctions) {} 357 358 LogicalResult 359 matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 360 ConversionPatternRewriter &rewriter) const override { 361 // We can only await on one the `AwaitableType` (for `await` it can be 362 // a `token` or a `value`, for `await_all` it must be a `group`). 363 if (!op.operand().getType().template isa<AwaitableType>()) 364 return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 365 366 // Check if await operation is inside the outlined coroutine function. 367 auto func = op->template getParentOfType<FuncOp>(); 368 auto outlined = outlinedFunctions.find(func); 369 const bool isInCoroutine = outlined != outlinedFunctions.end(); 370 371 Location loc = op->getLoc(); 372 Value operand = AwaitAdaptor(operands).operand(); 373 374 // Inside regular functions we use the blocking wait operation to wait for 375 // the async object (token, value or group) to become available. 376 if (!isInCoroutine) 377 rewriter.create<RuntimeAwaitOp>(loc, operand); 378 379 // Inside the coroutine we convert await operation into coroutine suspension 380 // point, and resume execution asynchronously. 381 if (isInCoroutine) { 382 CoroMachinery &coro = outlined->getSecond(); 383 Block *suspended = op->getBlock(); 384 385 ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 386 MLIRContext *ctx = op->getContext(); 387 388 // Save the coroutine state and resume on a runtime managed thread when 389 // the operand becomes available. 390 auto coroSaveOp = 391 builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 392 builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 393 394 // Split the entry block before the await operation. 395 Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 396 397 // Add async.coro.suspend as a suspended block terminator. 398 builder.setInsertionPointToEnd(suspended); 399 builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 400 coro.cleanup); 401 402 // Split the resume block into error checking and continuation. 403 Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 404 405 // Check if the awaited value is in the error state. 406 builder.setInsertionPointToStart(resume); 407 auto isError = 408 builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand); 409 builder.create<CondBranchOp>(isError, 410 /*trueDest=*/setupSetErrorBlock(coro), 411 /*trueArgs=*/ArrayRef<Value>(), 412 /*falseDest=*/continuation, 413 /*falseArgs=*/ArrayRef<Value>()); 414 415 // Make sure that replacement value will be constructed in the 416 // continuation block. 417 rewriter.setInsertionPointToStart(continuation); 418 } 419 420 // Erase or replace the await operation with the new value. 421 if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 422 rewriter.replaceOp(op, replaceWith); 423 else 424 rewriter.eraseOp(op); 425 426 return success(); 427 } 428 429 virtual Value getReplacementValue(AwaitType op, Value operand, 430 ConversionPatternRewriter &rewriter) const { 431 return Value(); 432 } 433 434 private: 435 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 436 }; 437 438 /// Lowering for `async.await` with a token operand. 439 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 440 using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 441 442 public: 443 using Base::Base; 444 }; 445 446 /// Lowering for `async.await` with a value operand. 447 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 448 using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 449 450 public: 451 using Base::Base; 452 453 Value 454 getReplacementValue(AwaitOp op, Value operand, 455 ConversionPatternRewriter &rewriter) const override { 456 // Load from the async value storage. 457 auto valueType = operand.getType().cast<ValueType>().getValueType(); 458 return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 459 } 460 }; 461 462 /// Lowering for `async.await_all` operation. 463 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 464 using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 465 466 public: 467 using Base::Base; 468 }; 469 470 } // namespace 471 472 //===----------------------------------------------------------------------===// 473 // Convert async.yield operation to async.runtime operations. 474 //===----------------------------------------------------------------------===// 475 476 class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 477 public: 478 YieldOpLowering( 479 MLIRContext *ctx, 480 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 481 : OpConversionPattern<async::YieldOp>(ctx), 482 outlinedFunctions(outlinedFunctions) {} 483 484 LogicalResult 485 matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 486 ConversionPatternRewriter &rewriter) const override { 487 // Check if yield operation is inside the async coroutine function. 488 auto func = op->template getParentOfType<FuncOp>(); 489 auto outlined = outlinedFunctions.find(func); 490 if (outlined == outlinedFunctions.end()) 491 return rewriter.notifyMatchFailure( 492 op, "operation is not inside the async coroutine function"); 493 494 Location loc = op->getLoc(); 495 const CoroMachinery &coro = outlined->getSecond(); 496 497 // Store yielded values into the async values storage and switch async 498 // values state to available. 499 for (auto tuple : llvm::zip(operands, coro.returnValues)) { 500 Value yieldValue = std::get<0>(tuple); 501 Value asyncValue = std::get<1>(tuple); 502 rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 503 rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 504 } 505 506 // Switch the coroutine completion token to available state. 507 rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 508 509 return success(); 510 } 511 512 private: 513 const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 514 }; 515 516 //===----------------------------------------------------------------------===// 517 // Convert std.assert operation to cond_br into `set_error` block. 518 //===----------------------------------------------------------------------===// 519 520 class AssertOpLowering : public OpConversionPattern<AssertOp> { 521 public: 522 AssertOpLowering(MLIRContext *ctx, 523 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 524 : OpConversionPattern<AssertOp>(ctx), 525 outlinedFunctions(outlinedFunctions) {} 526 527 LogicalResult 528 matchAndRewrite(AssertOp op, ArrayRef<Value> operands, 529 ConversionPatternRewriter &rewriter) const override { 530 // Check if assert operation is inside the async coroutine function. 531 auto func = op->template getParentOfType<FuncOp>(); 532 auto outlined = outlinedFunctions.find(func); 533 if (outlined == outlinedFunctions.end()) 534 return rewriter.notifyMatchFailure( 535 op, "operation is not inside the async coroutine function"); 536 537 Location loc = op->getLoc(); 538 CoroMachinery &coro = outlined->getSecond(); 539 540 Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 541 rewriter.setInsertionPointToEnd(cont->getPrevNode()); 542 rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(), 543 /*trueDest=*/cont, 544 /*trueArgs=*/ArrayRef<Value>(), 545 /*falseDest=*/setupSetErrorBlock(coro), 546 /*falseArgs=*/ArrayRef<Value>()); 547 rewriter.eraseOp(op); 548 549 return success(); 550 } 551 552 private: 553 llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 554 }; 555 556 //===----------------------------------------------------------------------===// 557 558 void AsyncToAsyncRuntimePass::runOnOperation() { 559 ModuleOp module = getOperation(); 560 SymbolTable symbolTable(module); 561 562 // Outline all `async.execute` body regions into async functions (coroutines). 563 llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 564 565 module.walk([&](ExecuteOp execute) { 566 outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 567 }); 568 569 LLVM_DEBUG({ 570 llvm::dbgs() << "Outlined " << outlinedFunctions.size() 571 << " functions built from async.execute operations\n"; 572 }); 573 574 // Lower async operations to async.runtime operations. 575 MLIRContext *ctx = module->getContext(); 576 RewritePatternSet asyncPatterns(ctx); 577 578 // Async lowering does not use type converter because it must preserve all 579 // types for async.runtime operations. 580 asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 581 asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 582 AwaitAllOpLowering, YieldOpLowering>(ctx, 583 outlinedFunctions); 584 585 // Lower assertions to conditional branches into error blocks. 586 asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 587 588 // All high level async operations must be lowered to the runtime operations. 589 ConversionTarget runtimeTarget(*ctx); 590 runtimeTarget.addLegalDialect<AsyncDialect>(); 591 runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 592 runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 593 594 // Assertions must be converted to runtime errors inside async functions. 595 runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 596 auto func = op->getParentOfType<FuncOp>(); 597 return outlinedFunctions.find(func) == outlinedFunctions.end(); 598 }); 599 runtimeTarget.addLegalOp<CondBranchOp>(); 600 601 if (failed(applyPartialConversion(module, runtimeTarget, 602 std::move(asyncPatterns)))) { 603 signalPassFailure(); 604 return; 605 } 606 } 607 608 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 609 return std::make_unique<AsyncToAsyncRuntimePass>(); 610 } 611