125f80e16SEugene Zhulenev //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// 225f80e16SEugene Zhulenev // 325f80e16SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 425f80e16SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 525f80e16SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 625f80e16SEugene Zhulenev // 725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 825f80e16SEugene Zhulenev // 925f80e16SEugene Zhulenev // This file implements lowering from high level async operations to async.coro 1025f80e16SEugene Zhulenev // and async.runtime operations. 1125f80e16SEugene Zhulenev // 1225f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 1325f80e16SEugene Zhulenev 1425f80e16SEugene Zhulenev #include "PassDetail.h" 1525f80e16SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 1625f80e16SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h" 1725f80e16SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h" 1825f80e16SEugene Zhulenev #include "mlir/IR/BlockAndValueMapping.h" 1925f80e16SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h" 2025f80e16SEugene Zhulenev #include "mlir/IR/PatternMatch.h" 2125f80e16SEugene Zhulenev #include "mlir/Transforms/DialectConversion.h" 2225f80e16SEugene Zhulenev #include "mlir/Transforms/RegionUtils.h" 2325f80e16SEugene Zhulenev #include "llvm/ADT/SetVector.h" 2425f80e16SEugene Zhulenev 2525f80e16SEugene Zhulenev using namespace mlir; 2625f80e16SEugene Zhulenev using namespace mlir::async; 2725f80e16SEugene Zhulenev 2825f80e16SEugene Zhulenev #define DEBUG_TYPE "async-to-async-runtime" 2925f80e16SEugene Zhulenev // Prefix for functions outlined from `async.execute` op regions. 3025f80e16SEugene Zhulenev static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; 3125f80e16SEugene Zhulenev 3225f80e16SEugene Zhulenev namespace { 3325f80e16SEugene Zhulenev 3425f80e16SEugene Zhulenev class AsyncToAsyncRuntimePass 3525f80e16SEugene Zhulenev : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> { 3625f80e16SEugene Zhulenev public: 3725f80e16SEugene Zhulenev AsyncToAsyncRuntimePass() = default; 3825f80e16SEugene Zhulenev void runOnOperation() override; 3925f80e16SEugene Zhulenev }; 4025f80e16SEugene Zhulenev 4125f80e16SEugene Zhulenev } // namespace 4225f80e16SEugene Zhulenev 4325f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 4425f80e16SEugene Zhulenev // async.execute op outlining to the coroutine functions. 4525f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 4625f80e16SEugene Zhulenev 4725f80e16SEugene Zhulenev /// Function targeted for coroutine transformation has two additional blocks at 4825f80e16SEugene Zhulenev /// the end: coroutine cleanup and coroutine suspension. 4925f80e16SEugene Zhulenev /// 5025f80e16SEugene Zhulenev /// async.await op lowering additionaly creates a resume block for each 5125f80e16SEugene Zhulenev /// operation to enable non-blocking waiting via coroutine suspension. 5225f80e16SEugene Zhulenev namespace { 5325f80e16SEugene Zhulenev struct CoroMachinery { 5425f80e16SEugene Zhulenev // Async execute region returns a completion token, and an async value for 5525f80e16SEugene Zhulenev // each yielded value. 5625f80e16SEugene Zhulenev // 5725f80e16SEugene Zhulenev // %token, %result = async.execute -> !async.value<T> { 5825f80e16SEugene Zhulenev // %0 = constant ... : T 5925f80e16SEugene Zhulenev // async.yield %0 : T 6025f80e16SEugene Zhulenev // } 6125f80e16SEugene Zhulenev Value asyncToken; // token representing completion of the async region 6225f80e16SEugene Zhulenev llvm::SmallVector<Value, 4> returnValues; // returned async values 6325f80e16SEugene Zhulenev 6425f80e16SEugene Zhulenev Value coroHandle; // coroutine handle (!async.coro.handle value) 6525f80e16SEugene Zhulenev Block *cleanup; // coroutine cleanup block 6625f80e16SEugene Zhulenev Block *suspend; // coroutine suspension block 6725f80e16SEugene Zhulenev }; 6825f80e16SEugene Zhulenev } // namespace 6925f80e16SEugene Zhulenev 7025f80e16SEugene Zhulenev /// Builds an coroutine template compatible with LLVM coroutines switched-resume 7125f80e16SEugene Zhulenev /// lowering using `async.runtime.*` and `async.coro.*` operations. 7225f80e16SEugene Zhulenev /// 7325f80e16SEugene Zhulenev /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html 7425f80e16SEugene Zhulenev /// 7525f80e16SEugene Zhulenev /// - `entry` block sets up the coroutine. 7625f80e16SEugene Zhulenev /// - `cleanup` block cleans up the coroutine state. 7725f80e16SEugene Zhulenev /// - `suspend block after the @llvm.coro.end() defines what value will be 7825f80e16SEugene Zhulenev /// returned to the initial caller of a coroutine. Everything before the 7925f80e16SEugene Zhulenev /// @llvm.coro.end() will be executed at every suspension point. 8025f80e16SEugene Zhulenev /// 8125f80e16SEugene Zhulenev /// Coroutine structure (only the important bits): 8225f80e16SEugene Zhulenev /// 8325f80e16SEugene Zhulenev /// func @async_execute_fn(<function-arguments>) 8425f80e16SEugene Zhulenev /// -> (!async.token, !async.value<T>) 8525f80e16SEugene Zhulenev /// { 8625f80e16SEugene Zhulenev /// ^entry(<function-arguments>): 8725f80e16SEugene Zhulenev /// %token = <async token> : !async.token // create async runtime token 8825f80e16SEugene Zhulenev /// %value = <async value> : !async.value<T> // create async value 8925f80e16SEugene Zhulenev /// %id = async.coro.id // create a coroutine id 9025f80e16SEugene Zhulenev /// %hdl = async.coro.begin %id // create a coroutine handle 9125f80e16SEugene Zhulenev /// br ^cleanup 9225f80e16SEugene Zhulenev /// 9325f80e16SEugene Zhulenev /// ^cleanup: 9425f80e16SEugene Zhulenev /// async.coro.free %hdl // delete the coroutine state 9525f80e16SEugene Zhulenev /// br ^suspend 9625f80e16SEugene Zhulenev /// 9725f80e16SEugene Zhulenev /// ^suspend: 9825f80e16SEugene Zhulenev /// async.coro.end %hdl // marks the end of a coroutine 9925f80e16SEugene Zhulenev /// return %token, %value : !async.token, !async.value<T> 10025f80e16SEugene Zhulenev /// } 10125f80e16SEugene Zhulenev /// 10225f80e16SEugene Zhulenev /// The actual code for the async.execute operation body region will be inserted 10325f80e16SEugene Zhulenev /// before the entry block terminator. 10425f80e16SEugene Zhulenev /// 10525f80e16SEugene Zhulenev /// 10625f80e16SEugene Zhulenev static CoroMachinery setupCoroMachinery(FuncOp func) { 10725f80e16SEugene Zhulenev assert(func.getBody().empty() && "Function must have empty body"); 10825f80e16SEugene Zhulenev 10925f80e16SEugene Zhulenev MLIRContext *ctx = func.getContext(); 11025f80e16SEugene Zhulenev Block *entryBlock = func.addEntryBlock(); 11125f80e16SEugene Zhulenev 11225f80e16SEugene Zhulenev auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); 11325f80e16SEugene Zhulenev 11425f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 11525f80e16SEugene Zhulenev // Allocate async token/values that we will return from a ramp function. 11625f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 11725f80e16SEugene Zhulenev auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result(); 11825f80e16SEugene Zhulenev 11925f80e16SEugene Zhulenev llvm::SmallVector<Value, 4> retValues; 12025f80e16SEugene Zhulenev for (auto resType : func.getCallableResults().drop_front()) 12125f80e16SEugene Zhulenev retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result()); 12225f80e16SEugene Zhulenev 12325f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 12425f80e16SEugene Zhulenev // Initialize coroutine: get coroutine id and coroutine handle. 12525f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 12625f80e16SEugene Zhulenev auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); 12725f80e16SEugene Zhulenev auto coroHdlOp = 12825f80e16SEugene Zhulenev builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id()); 12925f80e16SEugene Zhulenev 13025f80e16SEugene Zhulenev Block *cleanupBlock = func.addBlock(); 13125f80e16SEugene Zhulenev Block *suspendBlock = func.addBlock(); 13225f80e16SEugene Zhulenev 13325f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 13425f80e16SEugene Zhulenev // Coroutine cleanup block: deallocate coroutine frame, free the memory. 13525f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 13625f80e16SEugene Zhulenev builder.setInsertionPointToStart(cleanupBlock); 13725f80e16SEugene Zhulenev builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle()); 13825f80e16SEugene Zhulenev 13925f80e16SEugene Zhulenev // Branch into the suspend block. 14025f80e16SEugene Zhulenev builder.create<BranchOp>(suspendBlock); 14125f80e16SEugene Zhulenev 14225f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 14325f80e16SEugene Zhulenev // Coroutine suspend block: mark the end of a coroutine and return allocated 14425f80e16SEugene Zhulenev // async token. 14525f80e16SEugene Zhulenev // ------------------------------------------------------------------------ // 14625f80e16SEugene Zhulenev builder.setInsertionPointToStart(suspendBlock); 14725f80e16SEugene Zhulenev 14825f80e16SEugene Zhulenev // Mark the end of a coroutine: async.coro.end 14925f80e16SEugene Zhulenev builder.create<CoroEndOp>(coroHdlOp.handle()); 15025f80e16SEugene Zhulenev 15125f80e16SEugene Zhulenev // Return created `async.token` and `async.values` from the suspend block. 15225f80e16SEugene Zhulenev // This will be the return value of a coroutine ramp function. 15325f80e16SEugene Zhulenev SmallVector<Value, 4> ret{retToken}; 15425f80e16SEugene Zhulenev ret.insert(ret.end(), retValues.begin(), retValues.end()); 15525f80e16SEugene Zhulenev builder.create<ReturnOp>(ret); 15625f80e16SEugene Zhulenev 15725f80e16SEugene Zhulenev // Branch from the entry block to the cleanup block to create a valid CFG. 15825f80e16SEugene Zhulenev builder.setInsertionPointToEnd(entryBlock); 15925f80e16SEugene Zhulenev builder.create<BranchOp>(cleanupBlock); 16025f80e16SEugene Zhulenev 16125f80e16SEugene Zhulenev // `async.await` op lowering will create resume blocks for async 16225f80e16SEugene Zhulenev // continuations, and will conditionally branch to cleanup or suspend blocks. 16325f80e16SEugene Zhulenev 16425f80e16SEugene Zhulenev CoroMachinery machinery; 16525f80e16SEugene Zhulenev machinery.asyncToken = retToken; 16625f80e16SEugene Zhulenev machinery.returnValues = retValues; 16725f80e16SEugene Zhulenev machinery.coroHandle = coroHdlOp.handle(); 16825f80e16SEugene Zhulenev machinery.cleanup = cleanupBlock; 16925f80e16SEugene Zhulenev machinery.suspend = suspendBlock; 17025f80e16SEugene Zhulenev return machinery; 17125f80e16SEugene Zhulenev } 17225f80e16SEugene Zhulenev 17325f80e16SEugene Zhulenev /// Outline the body region attached to the `async.execute` op into a standalone 17425f80e16SEugene Zhulenev /// function. 17525f80e16SEugene Zhulenev /// 17625f80e16SEugene Zhulenev /// Note that this is not reversible transformation. 17725f80e16SEugene Zhulenev static std::pair<FuncOp, CoroMachinery> 17825f80e16SEugene Zhulenev outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { 17925f80e16SEugene Zhulenev ModuleOp module = execute->getParentOfType<ModuleOp>(); 18025f80e16SEugene Zhulenev 18125f80e16SEugene Zhulenev MLIRContext *ctx = module.getContext(); 18225f80e16SEugene Zhulenev Location loc = execute.getLoc(); 18325f80e16SEugene Zhulenev 18425f80e16SEugene Zhulenev // Collect all outlined function inputs. 18525f80e16SEugene Zhulenev llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(), 18625f80e16SEugene Zhulenev execute.dependencies().end()); 18725f80e16SEugene Zhulenev functionInputs.insert(execute.operands().begin(), execute.operands().end()); 18825f80e16SEugene Zhulenev getUsedValuesDefinedAbove(execute.body(), functionInputs); 18925f80e16SEugene Zhulenev 19025f80e16SEugene Zhulenev // Collect types for the outlined function inputs and outputs. 19125f80e16SEugene Zhulenev auto typesRange = llvm::map_range( 19225f80e16SEugene Zhulenev functionInputs, [](Value value) { return value.getType(); }); 19325f80e16SEugene Zhulenev SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end()); 19425f80e16SEugene Zhulenev auto outputTypes = execute.getResultTypes(); 19525f80e16SEugene Zhulenev 19625f80e16SEugene Zhulenev auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); 19725f80e16SEugene Zhulenev auto funcAttrs = ArrayRef<NamedAttribute>(); 19825f80e16SEugene Zhulenev 19925f80e16SEugene Zhulenev // TODO: Derive outlined function name from the parent FuncOp (support 20025f80e16SEugene Zhulenev // multiple nested async.execute operations). 20125f80e16SEugene Zhulenev FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); 20225f80e16SEugene Zhulenev symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); 20325f80e16SEugene Zhulenev 20425f80e16SEugene Zhulenev SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); 20525f80e16SEugene Zhulenev 20625f80e16SEugene Zhulenev // Prepare a function for coroutine lowering by adding entry/cleanup/suspend 20725f80e16SEugene Zhulenev // blocks, adding async.coro operations and setting up control flow. 20825f80e16SEugene Zhulenev CoroMachinery coro = setupCoroMachinery(func); 20925f80e16SEugene Zhulenev 21025f80e16SEugene Zhulenev // Suspend async function at the end of an entry block, and resume it using 21125f80e16SEugene Zhulenev // Async resume operation (execution will be resumed in a thread managed by 21225f80e16SEugene Zhulenev // the async runtime). 21325f80e16SEugene Zhulenev Block *entryBlock = &func.getBlocks().front(); 21425f80e16SEugene Zhulenev auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock); 21525f80e16SEugene Zhulenev 21625f80e16SEugene Zhulenev // Save the coroutine state: async.coro.save 21725f80e16SEugene Zhulenev auto coroSaveOp = 21825f80e16SEugene Zhulenev builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 21925f80e16SEugene Zhulenev 22025f80e16SEugene Zhulenev // Pass coroutine to the runtime to be resumed on a runtime managed thread. 22125f80e16SEugene Zhulenev builder.create<RuntimeResumeOp>(coro.coroHandle); 22225f80e16SEugene Zhulenev 22325f80e16SEugene Zhulenev // Split the entry block before the terminator (branch to suspend block). 22425f80e16SEugene Zhulenev auto *terminatorOp = entryBlock->getTerminator(); 22525f80e16SEugene Zhulenev Block *suspended = terminatorOp->getBlock(); 22625f80e16SEugene Zhulenev Block *resume = suspended->splitBlock(terminatorOp); 22725f80e16SEugene Zhulenev 22825f80e16SEugene Zhulenev // Add async.coro.suspend as a suspended block terminator. 22925f80e16SEugene Zhulenev builder.setInsertionPointToEnd(suspended); 23025f80e16SEugene Zhulenev builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 23125f80e16SEugene Zhulenev coro.cleanup); 23225f80e16SEugene Zhulenev 23325f80e16SEugene Zhulenev size_t numDependencies = execute.dependencies().size(); 23425f80e16SEugene Zhulenev size_t numOperands = execute.operands().size(); 23525f80e16SEugene Zhulenev 23625f80e16SEugene Zhulenev // Await on all dependencies before starting to execute the body region. 23725f80e16SEugene Zhulenev builder.setInsertionPointToStart(resume); 23825f80e16SEugene Zhulenev for (size_t i = 0; i < numDependencies; ++i) 23925f80e16SEugene Zhulenev builder.create<AwaitOp>(func.getArgument(i)); 24025f80e16SEugene Zhulenev 24125f80e16SEugene Zhulenev // Await on all async value operands and unwrap the payload. 24225f80e16SEugene Zhulenev SmallVector<Value, 4> unwrappedOperands(numOperands); 24325f80e16SEugene Zhulenev for (size_t i = 0; i < numOperands; ++i) { 24425f80e16SEugene Zhulenev Value operand = func.getArgument(numDependencies + i); 24525f80e16SEugene Zhulenev unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result(); 24625f80e16SEugene Zhulenev } 24725f80e16SEugene Zhulenev 24825f80e16SEugene Zhulenev // Map from function inputs defined above the execute op to the function 24925f80e16SEugene Zhulenev // arguments. 25025f80e16SEugene Zhulenev BlockAndValueMapping valueMapping; 25125f80e16SEugene Zhulenev valueMapping.map(functionInputs, func.getArguments()); 25225f80e16SEugene Zhulenev valueMapping.map(execute.body().getArguments(), unwrappedOperands); 25325f80e16SEugene Zhulenev 25425f80e16SEugene Zhulenev // Clone all operations from the execute operation body into the outlined 25525f80e16SEugene Zhulenev // function body. 25625f80e16SEugene Zhulenev for (Operation &op : execute.body().getOps()) 25725f80e16SEugene Zhulenev builder.clone(op, valueMapping); 25825f80e16SEugene Zhulenev 25925f80e16SEugene Zhulenev // Replace the original `async.execute` with a call to outlined function. 26025f80e16SEugene Zhulenev ImplicitLocOpBuilder callBuilder(loc, execute); 26125f80e16SEugene Zhulenev auto callOutlinedFunc = callBuilder.create<CallOp>( 26225f80e16SEugene Zhulenev func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); 26325f80e16SEugene Zhulenev execute.replaceAllUsesWith(callOutlinedFunc.getResults()); 26425f80e16SEugene Zhulenev execute.erase(); 26525f80e16SEugene Zhulenev 26625f80e16SEugene Zhulenev return {func, coro}; 26725f80e16SEugene Zhulenev } 26825f80e16SEugene Zhulenev 26925f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 27025f80e16SEugene Zhulenev // Convert async.create_group operation to async.runtime.create 27125f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 27225f80e16SEugene Zhulenev 27325f80e16SEugene Zhulenev namespace { 27425f80e16SEugene Zhulenev class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> { 27525f80e16SEugene Zhulenev public: 27625f80e16SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 27725f80e16SEugene Zhulenev 27825f80e16SEugene Zhulenev LogicalResult 27925f80e16SEugene Zhulenev matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands, 28025f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 28125f80e16SEugene Zhulenev rewriter.replaceOpWithNewOp<RuntimeCreateOp>( 28225f80e16SEugene Zhulenev op, GroupType::get(op->getContext())); 28325f80e16SEugene Zhulenev return success(); 28425f80e16SEugene Zhulenev } 28525f80e16SEugene Zhulenev }; 28625f80e16SEugene Zhulenev } // namespace 28725f80e16SEugene Zhulenev 28825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 28925f80e16SEugene Zhulenev // Convert async.add_to_group operation to async.runtime.add_to_group. 29025f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 29125f80e16SEugene Zhulenev 29225f80e16SEugene Zhulenev namespace { 29325f80e16SEugene Zhulenev class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> { 29425f80e16SEugene Zhulenev public: 29525f80e16SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 29625f80e16SEugene Zhulenev 29725f80e16SEugene Zhulenev LogicalResult 29825f80e16SEugene Zhulenev matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands, 29925f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 30025f80e16SEugene Zhulenev rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>( 30125f80e16SEugene Zhulenev op, rewriter.getIndexType(), operands); 30225f80e16SEugene Zhulenev return success(); 30325f80e16SEugene Zhulenev } 30425f80e16SEugene Zhulenev }; 30525f80e16SEugene Zhulenev } // namespace 30625f80e16SEugene Zhulenev 30725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 30825f80e16SEugene Zhulenev // Convert async.await and async.await_all operations to the async.runtime.await 30925f80e16SEugene Zhulenev // or async.runtime.await_and_resume operations. 31025f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 31125f80e16SEugene Zhulenev 31225f80e16SEugene Zhulenev namespace { 31325f80e16SEugene Zhulenev template <typename AwaitType, typename AwaitableType> 31425f80e16SEugene Zhulenev class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> { 31525f80e16SEugene Zhulenev using AwaitAdaptor = typename AwaitType::Adaptor; 31625f80e16SEugene Zhulenev 31725f80e16SEugene Zhulenev public: 31825f80e16SEugene Zhulenev AwaitOpLoweringBase( 31925f80e16SEugene Zhulenev MLIRContext *ctx, 32025f80e16SEugene Zhulenev const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 32125f80e16SEugene Zhulenev : OpConversionPattern<AwaitType>(ctx), 32225f80e16SEugene Zhulenev outlinedFunctions(outlinedFunctions) {} 32325f80e16SEugene Zhulenev 32425f80e16SEugene Zhulenev LogicalResult 32525f80e16SEugene Zhulenev matchAndRewrite(AwaitType op, ArrayRef<Value> operands, 32625f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 32725f80e16SEugene Zhulenev // We can only await on one the `AwaitableType` (for `await` it can be 32825f80e16SEugene Zhulenev // a `token` or a `value`, for `await_all` it must be a `group`). 32925f80e16SEugene Zhulenev if (!op.operand().getType().template isa<AwaitableType>()) 33025f80e16SEugene Zhulenev return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); 33125f80e16SEugene Zhulenev 33225f80e16SEugene Zhulenev // Check if await operation is inside the outlined coroutine function. 33325f80e16SEugene Zhulenev auto func = op->template getParentOfType<FuncOp>(); 33425f80e16SEugene Zhulenev auto outlined = outlinedFunctions.find(func); 33525f80e16SEugene Zhulenev const bool isInCoroutine = outlined != outlinedFunctions.end(); 33625f80e16SEugene Zhulenev 33725f80e16SEugene Zhulenev Location loc = op->getLoc(); 33825f80e16SEugene Zhulenev Value operand = AwaitAdaptor(operands).operand(); 33925f80e16SEugene Zhulenev 34025f80e16SEugene Zhulenev // Inside regular functions we use the blocking wait operation to wait for 34125f80e16SEugene Zhulenev // the async object (token, value or group) to become available. 34225f80e16SEugene Zhulenev if (!isInCoroutine) 34325f80e16SEugene Zhulenev rewriter.create<RuntimeAwaitOp>(loc, operand); 34425f80e16SEugene Zhulenev 34525f80e16SEugene Zhulenev // Inside the coroutine we convert await operation into coroutine suspension 34625f80e16SEugene Zhulenev // point, and resume execution asynchronously. 34725f80e16SEugene Zhulenev if (isInCoroutine) { 34825f80e16SEugene Zhulenev const CoroMachinery &coro = outlined->getSecond(); 34925f80e16SEugene Zhulenev Block *suspended = op->getBlock(); 35025f80e16SEugene Zhulenev 35125f80e16SEugene Zhulenev ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); 35225f80e16SEugene Zhulenev MLIRContext *ctx = op->getContext(); 35325f80e16SEugene Zhulenev 35425f80e16SEugene Zhulenev // Save the coroutine state and resume on a runtime managed thread when 35525f80e16SEugene Zhulenev // the operand becomes available. 35625f80e16SEugene Zhulenev auto coroSaveOp = 35725f80e16SEugene Zhulenev builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); 35825f80e16SEugene Zhulenev builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); 35925f80e16SEugene Zhulenev 36025f80e16SEugene Zhulenev // Split the entry block before the await operation. 36125f80e16SEugene Zhulenev Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); 36225f80e16SEugene Zhulenev 36325f80e16SEugene Zhulenev // Add async.coro.suspend as a suspended block terminator. 36425f80e16SEugene Zhulenev builder.setInsertionPointToEnd(suspended); 36525f80e16SEugene Zhulenev builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume, 36625f80e16SEugene Zhulenev coro.cleanup); 36725f80e16SEugene Zhulenev 36825f80e16SEugene Zhulenev // Make sure that replacement value will be constructed in resume block. 36925f80e16SEugene Zhulenev rewriter.setInsertionPointToStart(resume); 37025f80e16SEugene Zhulenev } 37125f80e16SEugene Zhulenev 37225f80e16SEugene Zhulenev // Erase or replace the await operation with the new value. 37325f80e16SEugene Zhulenev if (Value replaceWith = getReplacementValue(op, operand, rewriter)) 37425f80e16SEugene Zhulenev rewriter.replaceOp(op, replaceWith); 37525f80e16SEugene Zhulenev else 37625f80e16SEugene Zhulenev rewriter.eraseOp(op); 37725f80e16SEugene Zhulenev 37825f80e16SEugene Zhulenev return success(); 37925f80e16SEugene Zhulenev } 38025f80e16SEugene Zhulenev 38125f80e16SEugene Zhulenev virtual Value getReplacementValue(AwaitType op, Value operand, 38225f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const { 38325f80e16SEugene Zhulenev return Value(); 38425f80e16SEugene Zhulenev } 38525f80e16SEugene Zhulenev 38625f80e16SEugene Zhulenev private: 38725f80e16SEugene Zhulenev const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 38825f80e16SEugene Zhulenev }; 38925f80e16SEugene Zhulenev 39025f80e16SEugene Zhulenev /// Lowering for `async.await` with a token operand. 39125f80e16SEugene Zhulenev class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> { 39225f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitOp, TokenType>; 39325f80e16SEugene Zhulenev 39425f80e16SEugene Zhulenev public: 39525f80e16SEugene Zhulenev using Base::Base; 39625f80e16SEugene Zhulenev }; 39725f80e16SEugene Zhulenev 39825f80e16SEugene Zhulenev /// Lowering for `async.await` with a value operand. 39925f80e16SEugene Zhulenev class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> { 40025f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitOp, ValueType>; 40125f80e16SEugene Zhulenev 40225f80e16SEugene Zhulenev public: 40325f80e16SEugene Zhulenev using Base::Base; 40425f80e16SEugene Zhulenev 40525f80e16SEugene Zhulenev Value 40625f80e16SEugene Zhulenev getReplacementValue(AwaitOp op, Value operand, 40725f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 40825f80e16SEugene Zhulenev // Load from the async value storage. 40925f80e16SEugene Zhulenev auto valueType = operand.getType().cast<ValueType>().getValueType(); 41025f80e16SEugene Zhulenev return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); 41125f80e16SEugene Zhulenev } 41225f80e16SEugene Zhulenev }; 41325f80e16SEugene Zhulenev 41425f80e16SEugene Zhulenev /// Lowering for `async.await_all` operation. 41525f80e16SEugene Zhulenev class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> { 41625f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>; 41725f80e16SEugene Zhulenev 41825f80e16SEugene Zhulenev public: 41925f80e16SEugene Zhulenev using Base::Base; 42025f80e16SEugene Zhulenev }; 42125f80e16SEugene Zhulenev 42225f80e16SEugene Zhulenev } // namespace 42325f80e16SEugene Zhulenev 42425f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 42525f80e16SEugene Zhulenev // Convert async.yield operation to async.runtime operations. 42625f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 42725f80e16SEugene Zhulenev 42825f80e16SEugene Zhulenev class YieldOpLowering : public OpConversionPattern<async::YieldOp> { 42925f80e16SEugene Zhulenev public: 43025f80e16SEugene Zhulenev YieldOpLowering( 43125f80e16SEugene Zhulenev MLIRContext *ctx, 43225f80e16SEugene Zhulenev const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) 43325f80e16SEugene Zhulenev : OpConversionPattern<async::YieldOp>(ctx), 43425f80e16SEugene Zhulenev outlinedFunctions(outlinedFunctions) {} 43525f80e16SEugene Zhulenev 43625f80e16SEugene Zhulenev LogicalResult 43725f80e16SEugene Zhulenev matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands, 43825f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 43925f80e16SEugene Zhulenev // Check if yield operation is inside the outlined coroutine function. 44025f80e16SEugene Zhulenev auto func = op->template getParentOfType<FuncOp>(); 44125f80e16SEugene Zhulenev auto outlined = outlinedFunctions.find(func); 44225f80e16SEugene Zhulenev if (outlined == outlinedFunctions.end()) 44325f80e16SEugene Zhulenev return rewriter.notifyMatchFailure( 44425f80e16SEugene Zhulenev op, "operation is not inside the outlined async.execute function"); 44525f80e16SEugene Zhulenev 44625f80e16SEugene Zhulenev Location loc = op->getLoc(); 44725f80e16SEugene Zhulenev const CoroMachinery &coro = outlined->getSecond(); 44825f80e16SEugene Zhulenev 44925f80e16SEugene Zhulenev // Store yielded values into the async values storage and switch async 45025f80e16SEugene Zhulenev // values state to available. 45125f80e16SEugene Zhulenev for (auto tuple : llvm::zip(operands, coro.returnValues)) { 45225f80e16SEugene Zhulenev Value yieldValue = std::get<0>(tuple); 45325f80e16SEugene Zhulenev Value asyncValue = std::get<1>(tuple); 45425f80e16SEugene Zhulenev rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); 45525f80e16SEugene Zhulenev rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); 45625f80e16SEugene Zhulenev } 45725f80e16SEugene Zhulenev 45825f80e16SEugene Zhulenev // Switch the coroutine completion token to available state. 45925f80e16SEugene Zhulenev rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken); 46025f80e16SEugene Zhulenev 46125f80e16SEugene Zhulenev return success(); 46225f80e16SEugene Zhulenev } 46325f80e16SEugene Zhulenev 46425f80e16SEugene Zhulenev private: 46525f80e16SEugene Zhulenev const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions; 46625f80e16SEugene Zhulenev }; 46725f80e16SEugene Zhulenev 46825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===// 46925f80e16SEugene Zhulenev 47025f80e16SEugene Zhulenev void AsyncToAsyncRuntimePass::runOnOperation() { 47125f80e16SEugene Zhulenev ModuleOp module = getOperation(); 47225f80e16SEugene Zhulenev SymbolTable symbolTable(module); 47325f80e16SEugene Zhulenev 47425f80e16SEugene Zhulenev // Outline all `async.execute` body regions into async functions (coroutines). 47525f80e16SEugene Zhulenev llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions; 47625f80e16SEugene Zhulenev 47725f80e16SEugene Zhulenev module.walk([&](ExecuteOp execute) { 47825f80e16SEugene Zhulenev outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); 47925f80e16SEugene Zhulenev }); 48025f80e16SEugene Zhulenev 48125f80e16SEugene Zhulenev LLVM_DEBUG({ 48225f80e16SEugene Zhulenev llvm::dbgs() << "Outlined " << outlinedFunctions.size() 48325f80e16SEugene Zhulenev << " functions built from async.execute operations\n"; 48425f80e16SEugene Zhulenev }); 48525f80e16SEugene Zhulenev 48625f80e16SEugene Zhulenev // Lower async operations to async.runtime operations. 48725f80e16SEugene Zhulenev MLIRContext *ctx = module->getContext(); 488*dc4e913bSChris Lattner RewritePatternSet asyncPatterns(ctx); 48925f80e16SEugene Zhulenev 49025f80e16SEugene Zhulenev // Async lowering does not use type converter because it must preserve all 49125f80e16SEugene Zhulenev // types for async.runtime operations. 492*dc4e913bSChris Lattner asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx); 493*dc4e913bSChris Lattner asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, 49425f80e16SEugene Zhulenev AwaitAllOpLowering, YieldOpLowering>(ctx, 49525f80e16SEugene Zhulenev outlinedFunctions); 49625f80e16SEugene Zhulenev 49725f80e16SEugene Zhulenev // All high level async operations must be lowered to the runtime operations. 49825f80e16SEugene Zhulenev ConversionTarget runtimeTarget(*ctx); 49925f80e16SEugene Zhulenev runtimeTarget.addLegalDialect<AsyncDialect>(); 50025f80e16SEugene Zhulenev runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>(); 50125f80e16SEugene Zhulenev runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>(); 50225f80e16SEugene Zhulenev 50325f80e16SEugene Zhulenev if (failed(applyPartialConversion(module, runtimeTarget, 50425f80e16SEugene Zhulenev std::move(asyncPatterns)))) { 50525f80e16SEugene Zhulenev signalPassFailure(); 50625f80e16SEugene Zhulenev return; 50725f80e16SEugene Zhulenev } 50825f80e16SEugene Zhulenev } 50925f80e16SEugene Zhulenev 51025f80e16SEugene Zhulenev std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() { 51125f80e16SEugene Zhulenev return std::make_unique<AsyncToAsyncRuntimePass>(); 51225f80e16SEugene Zhulenev } 513