125f80e16SEugene Zhulenev //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 225f80e16SEugene Zhulenev // 325f80e16SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 425f80e16SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 525f80e16SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 625f80e16SEugene Zhulenev // 725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 825f80e16SEugene Zhulenev // 925f80e16SEugene Zhulenev // This file implements lowering from high level async operations to async.coro 1025f80e16SEugene Zhulenev // and async.runtime operations. 1125f80e16SEugene Zhulenev // 1225f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 1325f80e16SEugene Zhulenev 1425f80e16SEugene Zhulenev #include "PassDetail.h" 15de7a4e53SEugene Zhulenev #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" 16a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1725f80e16SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 1825f80e16SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h" 19de7a4e53SEugene Zhulenev #include "mlir/Dialect/SCF/SCF.h" 2025f80e16SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h" 2125f80e16SEugene Zhulenev #include "mlir/IR/BlockAndValueMapping.h" 2225f80e16SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h" 2325f80e16SEugene Zhulenev #include "mlir/IR/PatternMatch.h" 2425f80e16SEugene Zhulenev #include "mlir/Transforms/DialectConversion.h" 2525f80e16SEugene Zhulenev #include "mlir/Transforms/RegionUtils.h" 2625f80e16SEugene Zhulenev #include "llvm/ADT/SetVector.h" 27297a5b7cSNico Weber #include "llvm/Support/Debug.h" 2825f80e16SEugene Zhulenev 2925f80e16SEugene Zhulenev using namespace mlir; 3025f80e16SEugene Zhulenev using namespace mlir::async; 3125f80e16SEugene Zhulenev 3225f80e16SEugene Zhulenev #define DEBUG_TYPE "async-to-async-runtime" 3325f80e16SEugene Zhulenev // Prefix for functions outlined from `async.execute` op regions. 3425f80e16SEugene Zhulenev static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 3525f80e16SEugene Zhulenev 3625f80e16SEugene Zhulenev namespace { 3725f80e16SEugene Zhulenev 3825f80e16SEugene Zhulenev class AsyncToAsyncRuntimePass 3925f80e16SEugene Zhulenev : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 4025f80e16SEugene Zhulenev public: 4125f80e16SEugene Zhulenev AsyncToAsyncRuntimePass() = default; 4225f80e16SEugene Zhulenev void runOnOperation() override; 4325f80e16SEugene Zhulenev }; 4425f80e16SEugene Zhulenev 4525f80e16SEugene Zhulenev } // namespace 4625f80e16SEugene Zhulenev 4725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 4825f80e16SEugene Zhulenev // async.execute op outlining to the coroutine functions. 4925f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 5025f80e16SEugene Zhulenev 5125f80e16SEugene Zhulenev /// Function targeted for coroutine transformation has two additional blocks at 5225f80e16SEugene Zhulenev /// the end: coroutine cleanup and coroutine suspension. 5325f80e16SEugene Zhulenev /// 5425f80e16SEugene Zhulenev /// async.await op lowering additionaly creates a resume block for each 5525f80e16SEugene Zhulenev /// operation to enable non-blocking waiting via coroutine suspension. 5625f80e16SEugene Zhulenev namespace { 5725f80e16SEugene Zhulenev struct CoroMachinery { 5839957aa4SEugene Zhulenev FuncOp func; 5939957aa4SEugene Zhulenev 6025f80e16SEugene Zhulenev // Async execute region returns a completion token, and an async value for 6125f80e16SEugene Zhulenev // each yielded value. 6225f80e16SEugene Zhulenev // 6325f80e16SEugene Zhulenev // %token, %result = async.execute -> !async.value<T> { 64cb3aa49eSMogball // %0 = arith.constant ... : T 6525f80e16SEugene Zhulenev // async.yield %0 : T 6625f80e16SEugene Zhulenev // } 6725f80e16SEugene Zhulenev Value asyncToken; // token representing completion of the async region 6825f80e16SEugene Zhulenev llvm::SmallVector<Value, 4> returnValues; // returned async values 6925f80e16SEugene Zhulenev 7025f80e16SEugene Zhulenev Value coroHandle; // coroutine handle (!async.coro.handle value) 711c144410Sbakhtiyar Block *entry; // coroutine entry block 7239957aa4SEugene Zhulenev Block *setError; // switch completion token and all values to error state 7325f80e16SEugene Zhulenev Block *cleanup; // coroutine cleanup block 7425f80e16SEugene Zhulenev Block *suspend; // coroutine suspension block 7525f80e16SEugene Zhulenev }; 7625f80e16SEugene Zhulenev } // namespace 7725f80e16SEugene Zhulenev 786ea22d46Sbakhtiyar /// Utility to partially update the regular function CFG to the coroutine CFG 796ea22d46Sbakhtiyar /// compatible with LLVM coroutines switched-resume lowering using 801c144410Sbakhtiyar /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block 811c144410Sbakhtiyar /// that branches into preexisting entry block. Also inserts trailing blocks. 826ea22d46Sbakhtiyar /// 836ea22d46Sbakhtiyar /// The result types of the passed `func` must start with an `async.token` 846ea22d46Sbakhtiyar /// and be continued with some number of `async.value`s. 856ea22d46Sbakhtiyar /// 861c144410Sbakhtiyar /// The func given to this function needs to have been preprocessed to have 871c144410Sbakhtiyar /// either branch or yield ops as terminators. Branches to the cleanup block are 881c144410Sbakhtiyar /// inserted after each yield. 8925f80e16SEugene Zhulenev /// 9025f80e16SEugene Zhulenev /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 9125f80e16SEugene Zhulenev /// 9225f80e16SEugene Zhulenev /// - `entry` block sets up the coroutine. 9339957aa4SEugene Zhulenev /// - `set_error` block sets completion token and async values state to error. 9425f80e16SEugene Zhulenev /// - `cleanup` block cleans up the coroutine state. 9525f80e16SEugene Zhulenev /// - `suspend block after the @llvm.coro.end() defines what value will be 9625f80e16SEugene Zhulenev /// returned to the initial caller of a coroutine. Everything before the 9725f80e16SEugene Zhulenev /// @llvm.coro.end() will be executed at every suspension point. 9825f80e16SEugene Zhulenev /// 9925f80e16SEugene Zhulenev /// Coroutine structure (only the important bits): 10025f80e16SEugene Zhulenev /// 1016ea22d46Sbakhtiyar /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>) 10225f80e16SEugene Zhulenev /// { 10325f80e16SEugene Zhulenev /// ^entry(<function-arguments>): 10425f80e16SEugene Zhulenev /// %token = <async token> : !async.token // create async runtime token 10525f80e16SEugene Zhulenev /// %value = <async value> : !async.value<T> // create async value 10625f80e16SEugene Zhulenev /// %id = async.coro.id // create a coroutine id 10725f80e16SEugene Zhulenev /// %hdl = async.coro.begin %id // create a coroutine handle 1081c144410Sbakhtiyar /// br ^preexisting_entry_block 1096ea22d46Sbakhtiyar /// 1101c144410Sbakhtiyar /// /* preexisting blocks modified to branch to the cleanup block */ 11125f80e16SEugene Zhulenev /// 11239957aa4SEugene Zhulenev /// ^set_error: // this block created lazily only if needed (see code below) 11339957aa4SEugene Zhulenev /// async.runtime.set_error %token : !async.token 11439957aa4SEugene Zhulenev /// async.runtime.set_error %value : !async.value<T> 11539957aa4SEugene Zhulenev /// br ^cleanup 11639957aa4SEugene Zhulenev /// 11725f80e16SEugene Zhulenev /// ^cleanup: 11825f80e16SEugene Zhulenev /// async.coro.free %hdl // delete the coroutine state 11925f80e16SEugene Zhulenev /// br ^suspend 12025f80e16SEugene Zhulenev /// 12125f80e16SEugene Zhulenev /// ^suspend: 12225f80e16SEugene Zhulenev /// async.coro.end %hdl // marks the end of a coroutine 12325f80e16SEugene Zhulenev /// return %token, %value : !async.token, !async.value<T> 12425f80e16SEugene Zhulenev /// } 12525f80e16SEugene Zhulenev /// 12625f80e16SEugene Zhulenev static CoroMachinery setupCoroMachinery(FuncOp func) { 1276ea22d46Sbakhtiyar assert(!func.getBlocks().empty() && "Function must have an entry block"); 12825f80e16SEugene Zhulenev 12925f80e16SEugene Zhulenev MLIRContext *ctx = func.getContext(); 1306ea22d46Sbakhtiyar Block *entryBlock = &func.getBlocks().front(); 1311c144410Sbakhtiyar Block *originalEntryBlock = 1321c144410Sbakhtiyar entryBlock->splitBlock(entryBlock->getOperations().begin()); 13325f80e16SEugene Zhulenev auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 13425f80e16SEugene Zhulenev 13525f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 13625f80e16SEugene Zhulenev // Allocate async token/values that we will return from a ramp function. 13725f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 13825f80e16SEugene Zhulenev auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 13925f80e16SEugene Zhulenev 14025f80e16SEugene Zhulenev llvm::SmallVector<Value, 4> retValues; 14125f80e16SEugene Zhulenev for (auto resType : func.getCallableResults().drop_front()) 14225f80e16SEugene Zhulenev retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 14325f80e16SEugene Zhulenev 14425f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 14525f80e16SEugene Zhulenev // Initialize coroutine: get coroutine id and coroutine handle. 14625f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 14725f80e16SEugene Zhulenev auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 14825f80e16SEugene Zhulenev auto coroHdlOp = 14925f80e16SEugene Zhulenev builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 1501c144410Sbakhtiyar builder.create<BranchOp>(originalEntryBlock); 15125f80e16SEugene Zhulenev 15225f80e16SEugene Zhulenev Block *cleanupBlock = func.addBlock(); 15325f80e16SEugene Zhulenev Block *suspendBlock = func.addBlock(); 15425f80e16SEugene Zhulenev 15525f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 15625f80e16SEugene Zhulenev // Coroutine cleanup block: deallocate coroutine frame, free the memory. 15725f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 15825f80e16SEugene Zhulenev builder.setInsertionPointToStart(cleanupBlock); 15925f80e16SEugene Zhulenev builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 16025f80e16SEugene Zhulenev 16125f80e16SEugene Zhulenev // Branch into the suspend block. 16225f80e16SEugene Zhulenev builder.create<BranchOp>(suspendBlock); 16325f80e16SEugene Zhulenev 16425f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 16525f80e16SEugene Zhulenev // Coroutine suspend block: mark the end of a coroutine and return allocated 16625f80e16SEugene Zhulenev // async token. 16725f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 16825f80e16SEugene Zhulenev builder.setInsertionPointToStart(suspendBlock); 16925f80e16SEugene Zhulenev 17025f80e16SEugene Zhulenev // Mark the end of a coroutine: async.coro.end 17125f80e16SEugene Zhulenev builder.create<CoroEndOp>(coroHdlOp.handle()); 17225f80e16SEugene Zhulenev 17325f80e16SEugene Zhulenev // Return created `async.token` and `async.values` from the suspend block. 17425f80e16SEugene Zhulenev // This will be the return value of a coroutine ramp function. 17525f80e16SEugene Zhulenev SmallVector<Value, 4> ret{retToken}; 17625f80e16SEugene Zhulenev ret.insert(ret.end(), retValues.begin(), retValues.end()); 17725f80e16SEugene Zhulenev builder.create<ReturnOp>(ret); 17825f80e16SEugene Zhulenev 17925f80e16SEugene Zhulenev // `async.await` op lowering will create resume blocks for async 18025f80e16SEugene Zhulenev // continuations, and will conditionally branch to cleanup or suspend blocks. 18125f80e16SEugene Zhulenev 1821c144410Sbakhtiyar for (Block &block : func.body().getBlocks()) { 1831c144410Sbakhtiyar if (&block == entryBlock || &block == cleanupBlock || 1841c144410Sbakhtiyar &block == suspendBlock) 1851c144410Sbakhtiyar continue; 1861c144410Sbakhtiyar Operation *terminator = block.getTerminator(); 1871c144410Sbakhtiyar if (auto yield = dyn_cast<YieldOp>(terminator)) { 1881c144410Sbakhtiyar builder.setInsertionPointToEnd(&block); 1891c144410Sbakhtiyar builder.create<BranchOp>(cleanupBlock); 1901c144410Sbakhtiyar } 1911c144410Sbakhtiyar } 1921c144410Sbakhtiyar 193*c75cedc2SChuanqi Xu // The switch-resumed API based coroutine should be marked with 194*c75cedc2SChuanqi Xu // "coroutine.presplit" attribute with value "0" to mark the function as a 195*c75cedc2SChuanqi Xu // coroutine. 196*c75cedc2SChuanqi Xu func->setAttr("passthrough", builder.getArrayAttr(builder.getArrayAttr( 197*c75cedc2SChuanqi Xu {builder.getStringAttr("coroutine.presplit"), 198*c75cedc2SChuanqi Xu builder.getStringAttr("0")}))); 199*c75cedc2SChuanqi Xu 20025f80e16SEugene Zhulenev CoroMachinery machinery; 20139957aa4SEugene Zhulenev machinery.func = func; 20225f80e16SEugene Zhulenev machinery.asyncToken = retToken; 20325f80e16SEugene Zhulenev machinery.returnValues = retValues; 20425f80e16SEugene Zhulenev machinery.coroHandle = coroHdlOp.handle(); 2051c144410Sbakhtiyar machinery.entry = entryBlock; 20639957aa4SEugene Zhulenev machinery.setError = nullptr; // created lazily only if needed 20725f80e16SEugene Zhulenev machinery.cleanup = cleanupBlock; 20825f80e16SEugene Zhulenev machinery.suspend = suspendBlock; 20925f80e16SEugene Zhulenev return machinery; 21025f80e16SEugene Zhulenev } 21125f80e16SEugene Zhulenev 21239957aa4SEugene Zhulenev // Lazily creates `set_error` block only if it is required for lowering to the 21339957aa4SEugene Zhulenev // runtime operations (see for example lowering of assert operation). 21439957aa4SEugene Zhulenev static Block *setupSetErrorBlock(CoroMachinery &coro) { 21539957aa4SEugene Zhulenev if (coro.setError) 21639957aa4SEugene Zhulenev return coro.setError; 21739957aa4SEugene Zhulenev 21839957aa4SEugene Zhulenev coro.setError = coro.func.addBlock(); 21939957aa4SEugene Zhulenev coro.setError->moveBefore(coro.cleanup); 22039957aa4SEugene Zhulenev 22139957aa4SEugene Zhulenev auto builder = 22239957aa4SEugene Zhulenev ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); 22339957aa4SEugene Zhulenev 22439957aa4SEugene Zhulenev // Coroutine set_error block: set error on token and all returned values. 22539957aa4SEugene Zhulenev builder.create<RuntimeSetErrorOp>(coro.asyncToken); 22639957aa4SEugene Zhulenev for (Value retValue : coro.returnValues) 22739957aa4SEugene Zhulenev builder.create<RuntimeSetErrorOp>(retValue); 22839957aa4SEugene Zhulenev 22939957aa4SEugene Zhulenev // Branch into the cleanup block. 23039957aa4SEugene Zhulenev builder.create<BranchOp>(coro.cleanup); 23139957aa4SEugene Zhulenev 23239957aa4SEugene Zhulenev return coro.setError; 23339957aa4SEugene Zhulenev } 23439957aa4SEugene Zhulenev 23525f80e16SEugene Zhulenev /// Outline the body region attached to the `async.execute` op into a standalone 23625f80e16SEugene Zhulenev /// function. 23725f80e16SEugene Zhulenev /// 23825f80e16SEugene Zhulenev /// Note that this is not reversible transformation. 23925f80e16SEugene Zhulenev static std::pair<FuncOp, CoroMachinery> 24025f80e16SEugene Zhulenev outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 24125f80e16SEugene Zhulenev ModuleOp module = execute->getParentOfType<ModuleOp>(); 24225f80e16SEugene Zhulenev 24325f80e16SEugene Zhulenev MLIRContext *ctx = module.getContext(); 24425f80e16SEugene Zhulenev Location loc = execute.getLoc(); 24525f80e16SEugene Zhulenev 246b537c5b4SEugene Zhulenev // Make sure that all constants will be inside the outlined async function to 247b537c5b4SEugene Zhulenev // reduce the number of function arguments. 248b537c5b4SEugene Zhulenev cloneConstantsIntoTheRegion(execute.body()); 249b537c5b4SEugene Zhulenev 25025f80e16SEugene Zhulenev // Collect all outlined function inputs. 2514efb7754SRiver Riddle SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 25225f80e16SEugene Zhulenev execute.dependencies().end()); 25325f80e16SEugene Zhulenev functionInputs.insert(execute.operands().begin(), execute.operands().end()); 25425f80e16SEugene Zhulenev getUsedValuesDefinedAbove(execute.body(), functionInputs); 25525f80e16SEugene Zhulenev 25625f80e16SEugene Zhulenev // Collect types for the outlined function inputs and outputs. 25725f80e16SEugene Zhulenev auto typesRange = llvm::map_range( 25825f80e16SEugene Zhulenev functionInputs, [](Value value) { return value.getType(); }); 25925f80e16SEugene Zhulenev SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 26025f80e16SEugene Zhulenev auto outputTypes = execute.getResultTypes(); 26125f80e16SEugene Zhulenev 26225f80e16SEugene Zhulenev auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 26325f80e16SEugene Zhulenev auto funcAttrs = ArrayRef<NamedAttribute>(); 26425f80e16SEugene Zhulenev 26525f80e16SEugene Zhulenev // TODO: Derive outlined function name from the parent FuncOp (support 26625f80e16SEugene Zhulenev // multiple nested async.execute operations). 26725f80e16SEugene Zhulenev FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 268973ddb7dSMehdi Amini symbolTable.insert(func); 26925f80e16SEugene Zhulenev 27025f80e16SEugene Zhulenev SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 2711c144410Sbakhtiyar auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); 27225f80e16SEugene Zhulenev 2731c144410Sbakhtiyar // Prepare for coroutine conversion by creating the body of the function. 2741c144410Sbakhtiyar { 27525f80e16SEugene Zhulenev size_t numDependencies = execute.dependencies().size(); 27625f80e16SEugene Zhulenev size_t numOperands = execute.operands().size(); 27725f80e16SEugene Zhulenev 27825f80e16SEugene Zhulenev // Await on all dependencies before starting to execute the body region. 27925f80e16SEugene Zhulenev for (size_t i = 0; i < numDependencies; ++i) 28025f80e16SEugene Zhulenev builder.create<AwaitOp>(func.getArgument(i)); 28125f80e16SEugene Zhulenev 28225f80e16SEugene Zhulenev // Await on all async value operands and unwrap the payload. 28325f80e16SEugene Zhulenev SmallVector<Value, 4> unwrappedOperands(numOperands); 28425f80e16SEugene Zhulenev for (size_t i = 0; i < numOperands; ++i) { 28525f80e16SEugene Zhulenev Value operand = func.getArgument(numDependencies + i); 28625f80e16SEugene Zhulenev unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 28725f80e16SEugene Zhulenev } 28825f80e16SEugene Zhulenev 28925f80e16SEugene Zhulenev // Map from function inputs defined above the execute op to the function 29025f80e16SEugene Zhulenev // arguments. 29125f80e16SEugene Zhulenev BlockAndValueMapping valueMapping; 29225f80e16SEugene Zhulenev valueMapping.map(functionInputs, func.getArguments()); 29325f80e16SEugene Zhulenev valueMapping.map(execute.body().getArguments(), unwrappedOperands); 29425f80e16SEugene Zhulenev 29525f80e16SEugene Zhulenev // Clone all operations from the execute operation body into the outlined 29625f80e16SEugene Zhulenev // function body. 29725f80e16SEugene Zhulenev for (Operation &op : execute.body().getOps()) 29825f80e16SEugene Zhulenev builder.clone(op, valueMapping); 2991c144410Sbakhtiyar } 3001c144410Sbakhtiyar 3011c144410Sbakhtiyar // Adding entry/cleanup/suspend blocks. 3021c144410Sbakhtiyar CoroMachinery coro = setupCoroMachinery(func); 3031c144410Sbakhtiyar 3041c144410Sbakhtiyar // Suspend async function at the end of an entry block, and resume it using 3051c144410Sbakhtiyar // Async resume operation (execution will be resumed in a thread managed by 3061c144410Sbakhtiyar // the async runtime). 3071c144410Sbakhtiyar { 3081c144410Sbakhtiyar BranchOp branch = cast<BranchOp>(coro.entry->getTerminator()); 3091c144410Sbakhtiyar builder.setInsertionPointToEnd(coro.entry); 3101c144410Sbakhtiyar 3111c144410Sbakhtiyar // Save the coroutine state: async.coro.save 3121c144410Sbakhtiyar auto coroSaveOp = 3131c144410Sbakhtiyar builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 3141c144410Sbakhtiyar 3151c144410Sbakhtiyar // Pass coroutine to the runtime to be resumed on a runtime managed 3161c144410Sbakhtiyar // thread. 3171c144410Sbakhtiyar builder.create<RuntimeResumeOp>(coro.coroHandle); 3181c144410Sbakhtiyar 3191c144410Sbakhtiyar // Add async.coro.suspend as a suspended block terminator. 3201c144410Sbakhtiyar builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, 3211c144410Sbakhtiyar branch.getDest(), coro.cleanup); 3221c144410Sbakhtiyar 3231c144410Sbakhtiyar branch.erase(); 3241c144410Sbakhtiyar } 32525f80e16SEugene Zhulenev 32625f80e16SEugene Zhulenev // Replace the original `async.execute` with a call to outlined function. 3271c144410Sbakhtiyar { 32825f80e16SEugene Zhulenev ImplicitLocOpBuilder callBuilder(loc, execute); 32925f80e16SEugene Zhulenev auto callOutlinedFunc = callBuilder.create<CallOp>( 33025f80e16SEugene Zhulenev func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 33125f80e16SEugene Zhulenev execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 33225f80e16SEugene Zhulenev execute.erase(); 3331c144410Sbakhtiyar } 33425f80e16SEugene Zhulenev 33525f80e16SEugene Zhulenev return {func, coro}; 33625f80e16SEugene Zhulenev } 33725f80e16SEugene Zhulenev 33825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 339d43b2360SEugene Zhulenev // Convert async.create_group operation to async.runtime.create_group 34025f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 34125f80e16SEugene Zhulenev 34225f80e16SEugene Zhulenev namespace { 34325f80e16SEugene Zhulenev class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 34425f80e16SEugene Zhulenev public: 34525f80e16SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 34625f80e16SEugene Zhulenev 34725f80e16SEugene Zhulenev LogicalResult 348b54c724bSRiver Riddle matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, 34925f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 350d43b2360SEugene Zhulenev rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>( 351b54c724bSRiver Riddle op, GroupType::get(op->getContext()), adaptor.getOperands()); 35225f80e16SEugene Zhulenev return success(); 35325f80e16SEugene Zhulenev } 35425f80e16SEugene Zhulenev }; 35525f80e16SEugene Zhulenev } // namespace 35625f80e16SEugene Zhulenev 35725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 35825f80e16SEugene Zhulenev // Convert async.add_to_group operation to async.runtime.add_to_group. 35925f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 36025f80e16SEugene Zhulenev 36125f80e16SEugene Zhulenev namespace { 36225f80e16SEugene Zhulenev class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 36325f80e16SEugene Zhulenev public: 36425f80e16SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 36525f80e16SEugene Zhulenev 36625f80e16SEugene Zhulenev LogicalResult 367b54c724bSRiver Riddle matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, 36825f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 36925f80e16SEugene Zhulenev rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 370b54c724bSRiver Riddle op, rewriter.getIndexType(), adaptor.getOperands()); 37125f80e16SEugene Zhulenev return success(); 37225f80e16SEugene Zhulenev } 37325f80e16SEugene Zhulenev }; 37425f80e16SEugene Zhulenev } // namespace 37525f80e16SEugene Zhulenev 37625f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 37725f80e16SEugene Zhulenev // Convert async.await and async.await_all operations to the async.runtime.await 37825f80e16SEugene Zhulenev // or async.runtime.await_and_resume operations. 37925f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 38025f80e16SEugene Zhulenev 38125f80e16SEugene Zhulenev namespace { 38225f80e16SEugene Zhulenev template <typename AwaitType, typename AwaitableType> 38325f80e16SEugene Zhulenev class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 38425f80e16SEugene Zhulenev using AwaitAdaptor = typename AwaitType::Adaptor; 38525f80e16SEugene Zhulenev 38625f80e16SEugene Zhulenev public: 38739957aa4SEugene Zhulenev AwaitOpLoweringBase(MLIRContext *ctx, 38839957aa4SEugene Zhulenev llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 38925f80e16SEugene Zhulenev : OpConversionPattern<AwaitType>(ctx), 39025f80e16SEugene Zhulenev outlinedFunctions(outlinedFunctions) {} 39125f80e16SEugene Zhulenev 39225f80e16SEugene Zhulenev LogicalResult 393b54c724bSRiver Riddle matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, 39425f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 39525f80e16SEugene Zhulenev // We can only await on one the `AwaitableType` (for `await` it can be 39625f80e16SEugene Zhulenev // a `token` or a `value`, for `await_all` it must be a `group`). 39725f80e16SEugene Zhulenev if (!op.operand().getType().template isa<AwaitableType>()) 39825f80e16SEugene Zhulenev return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 39925f80e16SEugene Zhulenev 40025f80e16SEugene Zhulenev // Check if await operation is inside the outlined coroutine function. 40125f80e16SEugene Zhulenev auto func = op->template getParentOfType<FuncOp>(); 40225f80e16SEugene Zhulenev auto outlined = outlinedFunctions.find(func); 40325f80e16SEugene Zhulenev const bool isInCoroutine = outlined != outlinedFunctions.end(); 40425f80e16SEugene Zhulenev 40525f80e16SEugene Zhulenev Location loc = op->getLoc(); 406b54c724bSRiver Riddle Value operand = adaptor.operand(); 40725f80e16SEugene Zhulenev 408fd52b435SEugene Zhulenev Type i1 = rewriter.getI1Type(); 409fd52b435SEugene Zhulenev 41025f80e16SEugene Zhulenev // Inside regular functions we use the blocking wait operation to wait for 41125f80e16SEugene Zhulenev // the async object (token, value or group) to become available. 412fd52b435SEugene Zhulenev if (!isInCoroutine) { 413fd52b435SEugene Zhulenev ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 414fd52b435SEugene Zhulenev builder.create<RuntimeAwaitOp>(loc, operand); 415fd52b435SEugene Zhulenev 416fd52b435SEugene Zhulenev // Assert that the awaited operands is not in the error state. 417fd52b435SEugene Zhulenev Value isError = builder.create<RuntimeIsErrorOp>(i1, operand); 418a54f4eaeSMogball Value notError = builder.create<arith::XOrIOp>( 419a54f4eaeSMogball isError, builder.create<arith::ConstantOp>( 420a54f4eaeSMogball loc, i1, builder.getIntegerAttr(i1, 1))); 421fd52b435SEugene Zhulenev 422fd52b435SEugene Zhulenev builder.create<AssertOp>(notError, 423fd52b435SEugene Zhulenev "Awaited async operand is in error state"); 424fd52b435SEugene Zhulenev } 42525f80e16SEugene Zhulenev 42625f80e16SEugene Zhulenev // Inside the coroutine we convert await operation into coroutine suspension 42725f80e16SEugene Zhulenev // point, and resume execution asynchronously. 42825f80e16SEugene Zhulenev if (isInCoroutine) { 42939957aa4SEugene Zhulenev CoroMachinery &coro = outlined->getSecond(); 43025f80e16SEugene Zhulenev Block *suspended = op->getBlock(); 43125f80e16SEugene Zhulenev 43225f80e16SEugene Zhulenev ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 43325f80e16SEugene Zhulenev MLIRContext *ctx = op->getContext(); 43425f80e16SEugene Zhulenev 43525f80e16SEugene Zhulenev // Save the coroutine state and resume on a runtime managed thread when 43625f80e16SEugene Zhulenev // the operand becomes available. 43725f80e16SEugene Zhulenev auto coroSaveOp = 43825f80e16SEugene Zhulenev builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 43925f80e16SEugene Zhulenev builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 44025f80e16SEugene Zhulenev 44125f80e16SEugene Zhulenev // Split the entry block before the await operation. 44225f80e16SEugene Zhulenev Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 44325f80e16SEugene Zhulenev 44425f80e16SEugene Zhulenev // Add async.coro.suspend as a suspended block terminator. 44525f80e16SEugene Zhulenev builder.setInsertionPointToEnd(suspended); 44625f80e16SEugene Zhulenev builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 44725f80e16SEugene Zhulenev coro.cleanup); 44825f80e16SEugene Zhulenev 44939957aa4SEugene Zhulenev // Split the resume block into error checking and continuation. 45039957aa4SEugene Zhulenev Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); 45139957aa4SEugene Zhulenev 45239957aa4SEugene Zhulenev // Check if the awaited value is in the error state. 45339957aa4SEugene Zhulenev builder.setInsertionPointToStart(resume); 454fd52b435SEugene Zhulenev auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); 45539957aa4SEugene Zhulenev builder.create<CondBranchOp>(isError, 45639957aa4SEugene Zhulenev /*trueDest=*/setupSetErrorBlock(coro), 45739957aa4SEugene Zhulenev /*trueArgs=*/ArrayRef<Value>(), 45839957aa4SEugene Zhulenev /*falseDest=*/continuation, 45939957aa4SEugene Zhulenev /*falseArgs=*/ArrayRef<Value>()); 46039957aa4SEugene Zhulenev 46139957aa4SEugene Zhulenev // Make sure that replacement value will be constructed in the 46239957aa4SEugene Zhulenev // continuation block. 46339957aa4SEugene Zhulenev rewriter.setInsertionPointToStart(continuation); 46439957aa4SEugene Zhulenev } 46525f80e16SEugene Zhulenev 46625f80e16SEugene Zhulenev // Erase or replace the await operation with the new value. 46725f80e16SEugene Zhulenev if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 46825f80e16SEugene Zhulenev rewriter.replaceOp(op, replaceWith); 46925f80e16SEugene Zhulenev else 47025f80e16SEugene Zhulenev rewriter.eraseOp(op); 47125f80e16SEugene Zhulenev 47225f80e16SEugene Zhulenev return success(); 47325f80e16SEugene Zhulenev } 47425f80e16SEugene Zhulenev 47525f80e16SEugene Zhulenev virtual Value getReplacementValue(AwaitType op, Value operand, 47625f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const { 47725f80e16SEugene Zhulenev return Value(); 47825f80e16SEugene Zhulenev } 47925f80e16SEugene Zhulenev 48025f80e16SEugene Zhulenev private: 48139957aa4SEugene Zhulenev llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 48225f80e16SEugene Zhulenev }; 48325f80e16SEugene Zhulenev 48425f80e16SEugene Zhulenev /// Lowering for `async.await` with a token operand. 48525f80e16SEugene Zhulenev class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 48625f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 48725f80e16SEugene Zhulenev 48825f80e16SEugene Zhulenev public: 48925f80e16SEugene Zhulenev using Base::Base; 49025f80e16SEugene Zhulenev }; 49125f80e16SEugene Zhulenev 49225f80e16SEugene Zhulenev /// Lowering for `async.await` with a value operand. 49325f80e16SEugene Zhulenev class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 49425f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 49525f80e16SEugene Zhulenev 49625f80e16SEugene Zhulenev public: 49725f80e16SEugene Zhulenev using Base::Base; 49825f80e16SEugene Zhulenev 49925f80e16SEugene Zhulenev Value 50025f80e16SEugene Zhulenev getReplacementValue(AwaitOp op, Value operand, 50125f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 50225f80e16SEugene Zhulenev // Load from the async value storage. 50325f80e16SEugene Zhulenev auto valueType = operand.getType().cast<ValueType>().getValueType(); 50425f80e16SEugene Zhulenev return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 50525f80e16SEugene Zhulenev } 50625f80e16SEugene Zhulenev }; 50725f80e16SEugene Zhulenev 50825f80e16SEugene Zhulenev /// Lowering for `async.await_all` operation. 50925f80e16SEugene Zhulenev class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 51025f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 51125f80e16SEugene Zhulenev 51225f80e16SEugene Zhulenev public: 51325f80e16SEugene Zhulenev using Base::Base; 51425f80e16SEugene Zhulenev }; 51525f80e16SEugene Zhulenev 51625f80e16SEugene Zhulenev } // namespace 51725f80e16SEugene Zhulenev 51825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 51925f80e16SEugene Zhulenev // Convert async.yield operation to async.runtime operations. 52025f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 52125f80e16SEugene Zhulenev 52225f80e16SEugene Zhulenev class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 52325f80e16SEugene Zhulenev public: 52425f80e16SEugene Zhulenev YieldOpLowering( 52525f80e16SEugene Zhulenev MLIRContext *ctx, 52625f80e16SEugene Zhulenev const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 52725f80e16SEugene Zhulenev : OpConversionPattern<async::YieldOp>(ctx), 52825f80e16SEugene Zhulenev outlinedFunctions(outlinedFunctions) {} 52925f80e16SEugene Zhulenev 53025f80e16SEugene Zhulenev LogicalResult 531b54c724bSRiver Riddle matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 53225f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 53339957aa4SEugene Zhulenev // Check if yield operation is inside the async coroutine function. 53425f80e16SEugene Zhulenev auto func = op->template getParentOfType<FuncOp>(); 53525f80e16SEugene Zhulenev auto outlined = outlinedFunctions.find(func); 53625f80e16SEugene Zhulenev if (outlined == outlinedFunctions.end()) 53725f80e16SEugene Zhulenev return rewriter.notifyMatchFailure( 53839957aa4SEugene Zhulenev op, "operation is not inside the async coroutine function"); 53925f80e16SEugene Zhulenev 54025f80e16SEugene Zhulenev Location loc = op->getLoc(); 54125f80e16SEugene Zhulenev const CoroMachinery &coro = outlined->getSecond(); 54225f80e16SEugene Zhulenev 54325f80e16SEugene Zhulenev // Store yielded values into the async values storage and switch async 54425f80e16SEugene Zhulenev // values state to available. 545b54c724bSRiver Riddle for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { 54625f80e16SEugene Zhulenev Value yieldValue = std::get<0>(tuple); 54725f80e16SEugene Zhulenev Value asyncValue = std::get<1>(tuple); 54825f80e16SEugene Zhulenev rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 54925f80e16SEugene Zhulenev rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 55025f80e16SEugene Zhulenev } 55125f80e16SEugene Zhulenev 55225f80e16SEugene Zhulenev // Switch the coroutine completion token to available state. 55325f80e16SEugene Zhulenev rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 55425f80e16SEugene Zhulenev 55525f80e16SEugene Zhulenev return success(); 55625f80e16SEugene Zhulenev } 55725f80e16SEugene Zhulenev 55825f80e16SEugene Zhulenev private: 55925f80e16SEugene Zhulenev const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 56025f80e16SEugene Zhulenev }; 56125f80e16SEugene Zhulenev 56225f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 56339957aa4SEugene Zhulenev // Convert std.assert operation to cond_br into `set_error` block. 56439957aa4SEugene Zhulenev //===----------------------------------------------------------------------===// 56539957aa4SEugene Zhulenev 56639957aa4SEugene Zhulenev class AssertOpLowering : public OpConversionPattern<AssertOp> { 56739957aa4SEugene Zhulenev public: 56839957aa4SEugene Zhulenev AssertOpLowering(MLIRContext *ctx, 56939957aa4SEugene Zhulenev llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 57039957aa4SEugene Zhulenev : OpConversionPattern<AssertOp>(ctx), 57139957aa4SEugene Zhulenev outlinedFunctions(outlinedFunctions) {} 57239957aa4SEugene Zhulenev 57339957aa4SEugene Zhulenev LogicalResult 574b54c724bSRiver Riddle matchAndRewrite(AssertOp op, OpAdaptor adaptor, 57539957aa4SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 57639957aa4SEugene Zhulenev // Check if assert operation is inside the async coroutine function. 57739957aa4SEugene Zhulenev auto func = op->template getParentOfType<FuncOp>(); 57839957aa4SEugene Zhulenev auto outlined = outlinedFunctions.find(func); 57939957aa4SEugene Zhulenev if (outlined == outlinedFunctions.end()) 58039957aa4SEugene Zhulenev return rewriter.notifyMatchFailure( 58139957aa4SEugene Zhulenev op, "operation is not inside the async coroutine function"); 58239957aa4SEugene Zhulenev 58339957aa4SEugene Zhulenev Location loc = op->getLoc(); 58439957aa4SEugene Zhulenev CoroMachinery &coro = outlined->getSecond(); 58539957aa4SEugene Zhulenev 58639957aa4SEugene Zhulenev Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); 58739957aa4SEugene Zhulenev rewriter.setInsertionPointToEnd(cont->getPrevNode()); 588cfb72fd3SJacques Pienaar rewriter.create<CondBranchOp>(loc, adaptor.getArg(), 58939957aa4SEugene Zhulenev /*trueDest=*/cont, 59039957aa4SEugene Zhulenev /*trueArgs=*/ArrayRef<Value>(), 59139957aa4SEugene Zhulenev /*falseDest=*/setupSetErrorBlock(coro), 59239957aa4SEugene Zhulenev /*falseArgs=*/ArrayRef<Value>()); 59339957aa4SEugene Zhulenev rewriter.eraseOp(op); 59439957aa4SEugene Zhulenev 59539957aa4SEugene Zhulenev return success(); 59639957aa4SEugene Zhulenev } 59739957aa4SEugene Zhulenev 59839957aa4SEugene Zhulenev private: 59939957aa4SEugene Zhulenev llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 60039957aa4SEugene Zhulenev }; 60139957aa4SEugene Zhulenev 60239957aa4SEugene Zhulenev //===----------------------------------------------------------------------===// 60325f80e16SEugene Zhulenev 6046ea22d46Sbakhtiyar /// Rewrite a func as a coroutine by: 6056ea22d46Sbakhtiyar /// 1) Wrapping the results into `async.value`. 6066ea22d46Sbakhtiyar /// 2) Prepending the results with `async.token`. 6076ea22d46Sbakhtiyar /// 3) Setting up coroutine blocks. 6086ea22d46Sbakhtiyar /// 4) Rewriting return ops as yield op and branch op into the suspend block. 6096ea22d46Sbakhtiyar static CoroMachinery rewriteFuncAsCoroutine(FuncOp func) { 6106ea22d46Sbakhtiyar auto *ctx = func->getContext(); 6116ea22d46Sbakhtiyar auto loc = func.getLoc(); 6126ea22d46Sbakhtiyar SmallVector<Type> resultTypes; 6136ea22d46Sbakhtiyar resultTypes.reserve(func.getCallableResults().size()); 6146ea22d46Sbakhtiyar llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes), 6156ea22d46Sbakhtiyar [](Type type) { return ValueType::get(type); }); 6166ea22d46Sbakhtiyar func.setType(FunctionType::get(ctx, func.getType().getInputs(), resultTypes)); 6176ea22d46Sbakhtiyar func.insertResult(0, TokenType::get(ctx), {}); 6186ea22d46Sbakhtiyar for (Block &block : func.getBlocks()) { 6196ea22d46Sbakhtiyar Operation *terminator = block.getTerminator(); 6206ea22d46Sbakhtiyar if (auto returnOp = dyn_cast<ReturnOp>(*terminator)) { 6216ea22d46Sbakhtiyar ImplicitLocOpBuilder builder(loc, returnOp); 6226ea22d46Sbakhtiyar builder.create<YieldOp>(returnOp.getOperands()); 6236ea22d46Sbakhtiyar returnOp.erase(); 6246ea22d46Sbakhtiyar } 6256ea22d46Sbakhtiyar } 6261c144410Sbakhtiyar return setupCoroMachinery(func); 6276ea22d46Sbakhtiyar } 6286ea22d46Sbakhtiyar 6296ea22d46Sbakhtiyar /// Rewrites a call into a function that has been rewritten as a coroutine. 6306ea22d46Sbakhtiyar /// 6316ea22d46Sbakhtiyar /// The invocation of this function is safe only when call ops are traversed in 6326ea22d46Sbakhtiyar /// reverse order of how they appear in a single block. See `funcsToCoroutines`. 6336ea22d46Sbakhtiyar static void rewriteCallsiteForCoroutine(CallOp oldCall, FuncOp func) { 6346ea22d46Sbakhtiyar auto loc = func.getLoc(); 6356ea22d46Sbakhtiyar ImplicitLocOpBuilder callBuilder(loc, oldCall); 6366ea22d46Sbakhtiyar auto newCall = callBuilder.create<CallOp>( 6376ea22d46Sbakhtiyar func.getName(), func.getCallableResults(), oldCall.getArgOperands()); 6386ea22d46Sbakhtiyar 6396ea22d46Sbakhtiyar // Await on the async token and all the value results and unwrap the latter. 6406ea22d46Sbakhtiyar callBuilder.create<AwaitOp>(loc, newCall.getResults().front()); 6416ea22d46Sbakhtiyar SmallVector<Value> unwrappedResults; 6426ea22d46Sbakhtiyar unwrappedResults.reserve(newCall->getResults().size() - 1); 6436ea22d46Sbakhtiyar for (Value result : newCall.getResults().drop_front()) 6446ea22d46Sbakhtiyar unwrappedResults.push_back( 6456ea22d46Sbakhtiyar callBuilder.create<AwaitOp>(loc, result).result()); 6466ea22d46Sbakhtiyar // Careful, when result of a call is piped into another call this could lead 6476ea22d46Sbakhtiyar // to a dangling pointer. 6486ea22d46Sbakhtiyar oldCall.replaceAllUsesWith(unwrappedResults); 6496ea22d46Sbakhtiyar oldCall.erase(); 6506ea22d46Sbakhtiyar } 6516ea22d46Sbakhtiyar 6529a5bc836Sbakhtiyar static bool isAllowedToBlock(FuncOp func) { 6539a5bc836Sbakhtiyar return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName); 6549a5bc836Sbakhtiyar } 6559a5bc836Sbakhtiyar 6566ea22d46Sbakhtiyar static LogicalResult 6576ea22d46Sbakhtiyar funcsToCoroutines(ModuleOp module, 6586ea22d46Sbakhtiyar llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) { 6596ea22d46Sbakhtiyar // The following code supports the general case when 2 functions mutually 6606ea22d46Sbakhtiyar // recurse into each other. Because of this and that we are relying on 6616ea22d46Sbakhtiyar // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase 6626ea22d46Sbakhtiyar // a FuncOp while inserting an equivalent coroutine, because that could lead 6636ea22d46Sbakhtiyar // to dangling pointers. 6646ea22d46Sbakhtiyar 6656ea22d46Sbakhtiyar SmallVector<FuncOp> funcWorklist; 6666ea22d46Sbakhtiyar 6676ea22d46Sbakhtiyar // Careful, it's okay to add a func to the worklist multiple times if and only 6686ea22d46Sbakhtiyar // if the loop processing the worklist will skip the functions that have 6696ea22d46Sbakhtiyar // already been converted to coroutines. 6709a5bc836Sbakhtiyar auto addToWorklist = [&](FuncOp func) { 6719a5bc836Sbakhtiyar if (isAllowedToBlock(func)) 6729a5bc836Sbakhtiyar return; 6736ea22d46Sbakhtiyar // N.B. To refactor this code into a separate pass the lookup in 6746ea22d46Sbakhtiyar // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary 6756ea22d46Sbakhtiyar // func and recognizing if it has a coroutine structure is messy. Passing 6766ea22d46Sbakhtiyar // this dict between the passes is ugly. 6779a5bc836Sbakhtiyar if (isAllowedToBlock(func) || 6789a5bc836Sbakhtiyar outlinedFunctions.find(func) == outlinedFunctions.end()) { 6796ea22d46Sbakhtiyar for (Operation &op : func.body().getOps()) { 6806ea22d46Sbakhtiyar if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) { 6816ea22d46Sbakhtiyar funcWorklist.push_back(func); 6826ea22d46Sbakhtiyar break; 6836ea22d46Sbakhtiyar } 6846ea22d46Sbakhtiyar } 6856ea22d46Sbakhtiyar } 6866ea22d46Sbakhtiyar }; 6876ea22d46Sbakhtiyar 6886ea22d46Sbakhtiyar // Traverse in post-order collecting for each func op the await ops it has. 6896ea22d46Sbakhtiyar for (FuncOp func : module.getOps<FuncOp>()) 6906ea22d46Sbakhtiyar addToWorklist(func); 6916ea22d46Sbakhtiyar 6926ea22d46Sbakhtiyar SymbolTableCollection symbolTable; 6936ea22d46Sbakhtiyar SymbolUserMap symbolUserMap(symbolTable, module); 6946ea22d46Sbakhtiyar 6956ea22d46Sbakhtiyar // Rewrite funcs, while updating call sites and adding them to the worklist. 6966ea22d46Sbakhtiyar while (!funcWorklist.empty()) { 6976ea22d46Sbakhtiyar auto func = funcWorklist.pop_back_val(); 6986ea22d46Sbakhtiyar auto insertion = outlinedFunctions.insert({func, CoroMachinery{}}); 6996ea22d46Sbakhtiyar if (!insertion.second) 7006ea22d46Sbakhtiyar // This function has already been processed because this is either 7016ea22d46Sbakhtiyar // the corecursive case, or a caller with multiple calls to a newly 7026ea22d46Sbakhtiyar // created corouting. Either way, skip updating the call sites. 7036ea22d46Sbakhtiyar continue; 7046ea22d46Sbakhtiyar insertion.first->second = rewriteFuncAsCoroutine(func); 7056ea22d46Sbakhtiyar SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(), 7066ea22d46Sbakhtiyar symbolUserMap.getUsers(func).end()); 7076ea22d46Sbakhtiyar // If there are multiple calls from the same block they need to be traversed 7086ea22d46Sbakhtiyar // in reverse order so that symbolUserMap references are not invalidated 7096ea22d46Sbakhtiyar // when updating the users of the call op which is earlier in the block. 7106ea22d46Sbakhtiyar llvm::sort(users, [](Operation *a, Operation *b) { 7116ea22d46Sbakhtiyar Block *blockA = a->getBlock(); 7126ea22d46Sbakhtiyar Block *blockB = b->getBlock(); 7136ea22d46Sbakhtiyar // Impose arbitrary order on blocks so that there is a well-defined order. 7146ea22d46Sbakhtiyar return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b)); 7156ea22d46Sbakhtiyar }); 7166ea22d46Sbakhtiyar // Rewrite the callsites to await on results of the newly created coroutine. 7176ea22d46Sbakhtiyar for (Operation *op : users) { 7186ea22d46Sbakhtiyar if (CallOp call = dyn_cast<mlir::CallOp>(*op)) { 7196ea22d46Sbakhtiyar FuncOp caller = call->getParentOfType<FuncOp>(); 7206ea22d46Sbakhtiyar rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op. 7216ea22d46Sbakhtiyar addToWorklist(caller); 7226ea22d46Sbakhtiyar } else { 7236ea22d46Sbakhtiyar op->emitError("Unexpected reference to func referenced by symbol"); 7246ea22d46Sbakhtiyar return failure(); 7256ea22d46Sbakhtiyar } 7266ea22d46Sbakhtiyar } 7276ea22d46Sbakhtiyar } 7286ea22d46Sbakhtiyar return success(); 7296ea22d46Sbakhtiyar } 7306ea22d46Sbakhtiyar 7316ea22d46Sbakhtiyar //===----------------------------------------------------------------------===// 73225f80e16SEugene Zhulenev void AsyncToAsyncRuntimePass::runOnOperation() { 73325f80e16SEugene Zhulenev ModuleOp module = getOperation(); 73425f80e16SEugene Zhulenev SymbolTable symbolTable(module); 73525f80e16SEugene Zhulenev 73625f80e16SEugene Zhulenev // Outline all `async.execute` body regions into async functions (coroutines). 73725f80e16SEugene Zhulenev llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 73825f80e16SEugene Zhulenev 73925f80e16SEugene Zhulenev module.walk([&](ExecuteOp execute) { 74025f80e16SEugene Zhulenev outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 74125f80e16SEugene Zhulenev }); 74225f80e16SEugene Zhulenev 74325f80e16SEugene Zhulenev LLVM_DEBUG({ 74425f80e16SEugene Zhulenev llvm::dbgs() << "Outlined " << outlinedFunctions.size() 74525f80e16SEugene Zhulenev << " functions built from async.execute operations\n"; 74625f80e16SEugene Zhulenev }); 74725f80e16SEugene Zhulenev 748de7a4e53SEugene Zhulenev // Returns true if operation is inside the coroutine. 749de7a4e53SEugene Zhulenev auto isInCoroutine = [&](Operation *op) -> bool { 750de7a4e53SEugene Zhulenev auto parentFunc = op->getParentOfType<FuncOp>(); 751de7a4e53SEugene Zhulenev return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); 752de7a4e53SEugene Zhulenev }; 753de7a4e53SEugene Zhulenev 7546ea22d46Sbakhtiyar if (eliminateBlockingAwaitOps && 7556ea22d46Sbakhtiyar failed(funcsToCoroutines(module, outlinedFunctions))) { 7566ea22d46Sbakhtiyar signalPassFailure(); 7576ea22d46Sbakhtiyar return; 7586ea22d46Sbakhtiyar } 7596ea22d46Sbakhtiyar 76025f80e16SEugene Zhulenev // Lower async operations to async.runtime operations. 76125f80e16SEugene Zhulenev MLIRContext *ctx = module->getContext(); 762dc4e913bSChris Lattner RewritePatternSet asyncPatterns(ctx); 76325f80e16SEugene Zhulenev 764de7a4e53SEugene Zhulenev // Conversion to async runtime augments original CFG with the coroutine CFG, 765de7a4e53SEugene Zhulenev // and we have to make sure that structured control flow operations with async 766de7a4e53SEugene Zhulenev // operations in nested regions will be converted to branch-based control flow 767de7a4e53SEugene Zhulenev // before we add the coroutine basic blocks. 768de7a4e53SEugene Zhulenev populateLoopToStdConversionPatterns(asyncPatterns); 769de7a4e53SEugene Zhulenev 77025f80e16SEugene Zhulenev // Async lowering does not use type converter because it must preserve all 77125f80e16SEugene Zhulenev // types for async.runtime operations. 772dc4e913bSChris Lattner asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 773dc4e913bSChris Lattner asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 77425f80e16SEugene Zhulenev AwaitAllOpLowering, YieldOpLowering>(ctx, 77525f80e16SEugene Zhulenev outlinedFunctions); 77625f80e16SEugene Zhulenev 77739957aa4SEugene Zhulenev // Lower assertions to conditional branches into error blocks. 77839957aa4SEugene Zhulenev asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions); 77939957aa4SEugene Zhulenev 78025f80e16SEugene Zhulenev // All high level async operations must be lowered to the runtime operations. 78125f80e16SEugene Zhulenev ConversionTarget runtimeTarget(*ctx); 78225f80e16SEugene Zhulenev runtimeTarget.addLegalDialect<AsyncDialect>(); 78325f80e16SEugene Zhulenev runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 78425f80e16SEugene Zhulenev runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 78525f80e16SEugene Zhulenev 786de7a4e53SEugene Zhulenev // Decide if structured control flow has to be lowered to branch-based CFG. 787de7a4e53SEugene Zhulenev runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) { 788de7a4e53SEugene Zhulenev auto walkResult = op->walk([&](Operation *nested) { 789de7a4e53SEugene Zhulenev bool isAsync = isa<async::AsyncDialect>(nested->getDialect()); 790de7a4e53SEugene Zhulenev return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() 791de7a4e53SEugene Zhulenev : WalkResult::advance(); 792de7a4e53SEugene Zhulenev }); 793de7a4e53SEugene Zhulenev return !walkResult.wasInterrupted(); 794de7a4e53SEugene Zhulenev }); 795a54f4eaeSMogball runtimeTarget.addLegalOp<AssertOp, arith::XOrIOp, arith::ConstantOp, 796a54f4eaeSMogball ConstantOp, BranchOp, CondBranchOp>(); 797de7a4e53SEugene Zhulenev 7988f23fac4SEugene Zhulenev // Assertions must be converted to runtime errors inside async functions. 7998f23fac4SEugene Zhulenev runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool { 8008f23fac4SEugene Zhulenev auto func = op->getParentOfType<FuncOp>(); 8018f23fac4SEugene Zhulenev return outlinedFunctions.find(func) == outlinedFunctions.end(); 8028f23fac4SEugene Zhulenev }); 80339957aa4SEugene Zhulenev 8046ea22d46Sbakhtiyar if (eliminateBlockingAwaitOps) 8059a5bc836Sbakhtiyar runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>( 8069a5bc836Sbakhtiyar [&](RuntimeAwaitOp op) -> bool { 8079a5bc836Sbakhtiyar return isAllowedToBlock(op->getParentOfType<FuncOp>()); 8089a5bc836Sbakhtiyar }); 8096ea22d46Sbakhtiyar 81025f80e16SEugene Zhulenev if (failed(applyPartialConversion(module, runtimeTarget, 81125f80e16SEugene Zhulenev std::move(asyncPatterns)))) { 81225f80e16SEugene Zhulenev signalPassFailure(); 81325f80e16SEugene Zhulenev return; 81425f80e16SEugene Zhulenev } 81525f80e16SEugene Zhulenev } 81625f80e16SEugene Zhulenev 81725f80e16SEugene Zhulenev std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 81825f80e16SEugene Zhulenev return std::make_unique<AsyncToAsyncRuntimePass>(); 81925f80e16SEugene Zhulenev } 820