125f80e16SEugene Zhulenev //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
225f80e16SEugene Zhulenev //
325f80e16SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
425f80e16SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
525f80e16SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
625f80e16SEugene Zhulenev //
725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
825f80e16SEugene Zhulenev //
925f80e16SEugene Zhulenev // This file implements lowering from high level async operations to async.coro
1025f80e16SEugene Zhulenev // and async.runtime operations.
1125f80e16SEugene Zhulenev //
1225f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
1325f80e16SEugene Zhulenev 
1425f80e16SEugene Zhulenev #include "PassDetail.h"
1525f80e16SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
1625f80e16SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h"
1725f80e16SEugene Zhulenev #include "mlir/Dialect/StandardOps/IR/Ops.h"
1825f80e16SEugene Zhulenev #include "mlir/IR/BlockAndValueMapping.h"
1925f80e16SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h"
2025f80e16SEugene Zhulenev #include "mlir/IR/PatternMatch.h"
2125f80e16SEugene Zhulenev #include "mlir/Transforms/DialectConversion.h"
2225f80e16SEugene Zhulenev #include "mlir/Transforms/RegionUtils.h"
2325f80e16SEugene Zhulenev #include "llvm/ADT/SetVector.h"
2425f80e16SEugene Zhulenev 
2525f80e16SEugene Zhulenev using namespace mlir;
2625f80e16SEugene Zhulenev using namespace mlir::async;
2725f80e16SEugene Zhulenev 
#define DEBUG_TYPE "async-to-async-runtime"

// Symbol name given to every function outlined from an `async.execute` op
// region (used as the requested name when inserting into the symbol table).
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
3125f80e16SEugene Zhulenev 
3225f80e16SEugene Zhulenev namespace {
3325f80e16SEugene Zhulenev 
3425f80e16SEugene Zhulenev class AsyncToAsyncRuntimePass
3525f80e16SEugene Zhulenev     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
3625f80e16SEugene Zhulenev public:
3725f80e16SEugene Zhulenev   AsyncToAsyncRuntimePass() = default;
3825f80e16SEugene Zhulenev   void runOnOperation() override;
3925f80e16SEugene Zhulenev };
4025f80e16SEugene Zhulenev 
4125f80e16SEugene Zhulenev } // namespace
4225f80e16SEugene Zhulenev 
4325f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
4425f80e16SEugene Zhulenev // async.execute op outlining to the coroutine functions.
4525f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
4625f80e16SEugene Zhulenev 
4725f80e16SEugene Zhulenev /// Function targeted for coroutine transformation has two additional blocks at
4825f80e16SEugene Zhulenev /// the end: coroutine cleanup and coroutine suspension.
4925f80e16SEugene Zhulenev ///
5025f80e16SEugene Zhulenev /// async.await op lowering additionaly creates a resume block for each
5125f80e16SEugene Zhulenev /// operation to enable non-blocking waiting via coroutine suspension.
5225f80e16SEugene Zhulenev namespace {
5325f80e16SEugene Zhulenev struct CoroMachinery {
5425f80e16SEugene Zhulenev   // Async execute region returns a completion token, and an async value for
5525f80e16SEugene Zhulenev   // each yielded value.
5625f80e16SEugene Zhulenev   //
5725f80e16SEugene Zhulenev   //   %token, %result = async.execute -> !async.value<T> {
5825f80e16SEugene Zhulenev   //     %0 = constant ... : T
5925f80e16SEugene Zhulenev   //     async.yield %0 : T
6025f80e16SEugene Zhulenev   //   }
6125f80e16SEugene Zhulenev   Value asyncToken; // token representing completion of the async region
6225f80e16SEugene Zhulenev   llvm::SmallVector<Value, 4> returnValues; // returned async values
6325f80e16SEugene Zhulenev 
6425f80e16SEugene Zhulenev   Value coroHandle; // coroutine handle (!async.coro.handle value)
6525f80e16SEugene Zhulenev   Block *cleanup;   // coroutine cleanup block
6625f80e16SEugene Zhulenev   Block *suspend;   // coroutine suspension block
6725f80e16SEugene Zhulenev };
6825f80e16SEugene Zhulenev } // namespace
6925f80e16SEugene Zhulenev 
7025f80e16SEugene Zhulenev /// Builds an coroutine template compatible with LLVM coroutines switched-resume
7125f80e16SEugene Zhulenev /// lowering using `async.runtime.*` and `async.coro.*` operations.
7225f80e16SEugene Zhulenev ///
7325f80e16SEugene Zhulenev /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
7425f80e16SEugene Zhulenev ///
7525f80e16SEugene Zhulenev ///  - `entry` block sets up the coroutine.
7625f80e16SEugene Zhulenev ///  - `cleanup` block cleans up the coroutine state.
7725f80e16SEugene Zhulenev ///  - `suspend block after the @llvm.coro.end() defines what value will be
7825f80e16SEugene Zhulenev ///    returned to the initial caller of a coroutine. Everything before the
7925f80e16SEugene Zhulenev ///    @llvm.coro.end() will be executed at every suspension point.
8025f80e16SEugene Zhulenev ///
8125f80e16SEugene Zhulenev /// Coroutine structure (only the important bits):
8225f80e16SEugene Zhulenev ///
8325f80e16SEugene Zhulenev ///   func @async_execute_fn(<function-arguments>)
8425f80e16SEugene Zhulenev ///        -> (!async.token, !async.value<T>)
8525f80e16SEugene Zhulenev ///   {
8625f80e16SEugene Zhulenev ///     ^entry(<function-arguments>):
8725f80e16SEugene Zhulenev ///       %token = <async token> : !async.token    // create async runtime token
8825f80e16SEugene Zhulenev ///       %value = <async value> : !async.value<T> // create async value
8925f80e16SEugene Zhulenev ///       %id = async.coro.id                      // create a coroutine id
9025f80e16SEugene Zhulenev ///       %hdl = async.coro.begin %id              // create a coroutine handle
9125f80e16SEugene Zhulenev ///       br ^cleanup
9225f80e16SEugene Zhulenev ///
9325f80e16SEugene Zhulenev ///     ^cleanup:
9425f80e16SEugene Zhulenev ///       async.coro.free %hdl // delete the coroutine state
9525f80e16SEugene Zhulenev ///       br ^suspend
9625f80e16SEugene Zhulenev ///
9725f80e16SEugene Zhulenev ///     ^suspend:
9825f80e16SEugene Zhulenev ///       async.coro.end %hdl // marks the end of a coroutine
9925f80e16SEugene Zhulenev ///       return %token, %value : !async.token, !async.value<T>
10025f80e16SEugene Zhulenev ///   }
10125f80e16SEugene Zhulenev ///
10225f80e16SEugene Zhulenev /// The actual code for the async.execute operation body region will be inserted
10325f80e16SEugene Zhulenev /// before the entry block terminator.
10425f80e16SEugene Zhulenev ///
10525f80e16SEugene Zhulenev ///
10625f80e16SEugene Zhulenev static CoroMachinery setupCoroMachinery(FuncOp func) {
10725f80e16SEugene Zhulenev   assert(func.getBody().empty() && "Function must have empty body");
10825f80e16SEugene Zhulenev 
10925f80e16SEugene Zhulenev   MLIRContext *ctx = func.getContext();
11025f80e16SEugene Zhulenev   Block *entryBlock = func.addEntryBlock();
11125f80e16SEugene Zhulenev 
11225f80e16SEugene Zhulenev   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
11325f80e16SEugene Zhulenev 
11425f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
11525f80e16SEugene Zhulenev   // Allocate async token/values that we will return from a ramp function.
11625f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
11725f80e16SEugene Zhulenev   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
11825f80e16SEugene Zhulenev 
11925f80e16SEugene Zhulenev   llvm::SmallVector<Value, 4> retValues;
12025f80e16SEugene Zhulenev   for (auto resType : func.getCallableResults().drop_front())
12125f80e16SEugene Zhulenev     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
12225f80e16SEugene Zhulenev 
12325f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
12425f80e16SEugene Zhulenev   // Initialize coroutine: get coroutine id and coroutine handle.
12525f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
12625f80e16SEugene Zhulenev   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
12725f80e16SEugene Zhulenev   auto coroHdlOp =
12825f80e16SEugene Zhulenev       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
12925f80e16SEugene Zhulenev 
13025f80e16SEugene Zhulenev   Block *cleanupBlock = func.addBlock();
13125f80e16SEugene Zhulenev   Block *suspendBlock = func.addBlock();
13225f80e16SEugene Zhulenev 
13325f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
13425f80e16SEugene Zhulenev   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
13525f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
13625f80e16SEugene Zhulenev   builder.setInsertionPointToStart(cleanupBlock);
13725f80e16SEugene Zhulenev   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
13825f80e16SEugene Zhulenev 
13925f80e16SEugene Zhulenev   // Branch into the suspend block.
14025f80e16SEugene Zhulenev   builder.create<BranchOp>(suspendBlock);
14125f80e16SEugene Zhulenev 
14225f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
14325f80e16SEugene Zhulenev   // Coroutine suspend block: mark the end of a coroutine and return allocated
14425f80e16SEugene Zhulenev   // async token.
14525f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
14625f80e16SEugene Zhulenev   builder.setInsertionPointToStart(suspendBlock);
14725f80e16SEugene Zhulenev 
14825f80e16SEugene Zhulenev   // Mark the end of a coroutine: async.coro.end
14925f80e16SEugene Zhulenev   builder.create<CoroEndOp>(coroHdlOp.handle());
15025f80e16SEugene Zhulenev 
15125f80e16SEugene Zhulenev   // Return created `async.token` and `async.values` from the suspend block.
15225f80e16SEugene Zhulenev   // This will be the return value of a coroutine ramp function.
15325f80e16SEugene Zhulenev   SmallVector<Value, 4> ret{retToken};
15425f80e16SEugene Zhulenev   ret.insert(ret.end(), retValues.begin(), retValues.end());
15525f80e16SEugene Zhulenev   builder.create<ReturnOp>(ret);
15625f80e16SEugene Zhulenev 
15725f80e16SEugene Zhulenev   // Branch from the entry block to the cleanup block to create a valid CFG.
15825f80e16SEugene Zhulenev   builder.setInsertionPointToEnd(entryBlock);
15925f80e16SEugene Zhulenev   builder.create<BranchOp>(cleanupBlock);
16025f80e16SEugene Zhulenev 
16125f80e16SEugene Zhulenev   // `async.await` op lowering will create resume blocks for async
16225f80e16SEugene Zhulenev   // continuations, and will conditionally branch to cleanup or suspend blocks.
16325f80e16SEugene Zhulenev 
16425f80e16SEugene Zhulenev   CoroMachinery machinery;
16525f80e16SEugene Zhulenev   machinery.asyncToken = retToken;
16625f80e16SEugene Zhulenev   machinery.returnValues = retValues;
16725f80e16SEugene Zhulenev   machinery.coroHandle = coroHdlOp.handle();
16825f80e16SEugene Zhulenev   machinery.cleanup = cleanupBlock;
16925f80e16SEugene Zhulenev   machinery.suspend = suspendBlock;
17025f80e16SEugene Zhulenev   return machinery;
17125f80e16SEugene Zhulenev }
17225f80e16SEugene Zhulenev 
17325f80e16SEugene Zhulenev /// Outline the body region attached to the `async.execute` op into a standalone
17425f80e16SEugene Zhulenev /// function.
17525f80e16SEugene Zhulenev ///
17625f80e16SEugene Zhulenev /// Note that this is not reversible transformation.
17725f80e16SEugene Zhulenev static std::pair<FuncOp, CoroMachinery>
17825f80e16SEugene Zhulenev outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
17925f80e16SEugene Zhulenev   ModuleOp module = execute->getParentOfType<ModuleOp>();
18025f80e16SEugene Zhulenev 
18125f80e16SEugene Zhulenev   MLIRContext *ctx = module.getContext();
18225f80e16SEugene Zhulenev   Location loc = execute.getLoc();
18325f80e16SEugene Zhulenev 
18425f80e16SEugene Zhulenev   // Collect all outlined function inputs.
18525f80e16SEugene Zhulenev   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
18625f80e16SEugene Zhulenev                                               execute.dependencies().end());
18725f80e16SEugene Zhulenev   functionInputs.insert(execute.operands().begin(), execute.operands().end());
18825f80e16SEugene Zhulenev   getUsedValuesDefinedAbove(execute.body(), functionInputs);
18925f80e16SEugene Zhulenev 
19025f80e16SEugene Zhulenev   // Collect types for the outlined function inputs and outputs.
19125f80e16SEugene Zhulenev   auto typesRange = llvm::map_range(
19225f80e16SEugene Zhulenev       functionInputs, [](Value value) { return value.getType(); });
19325f80e16SEugene Zhulenev   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
19425f80e16SEugene Zhulenev   auto outputTypes = execute.getResultTypes();
19525f80e16SEugene Zhulenev 
19625f80e16SEugene Zhulenev   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
19725f80e16SEugene Zhulenev   auto funcAttrs = ArrayRef<NamedAttribute>();
19825f80e16SEugene Zhulenev 
19925f80e16SEugene Zhulenev   // TODO: Derive outlined function name from the parent FuncOp (support
20025f80e16SEugene Zhulenev   // multiple nested async.execute operations).
20125f80e16SEugene Zhulenev   FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
20225f80e16SEugene Zhulenev   symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
20325f80e16SEugene Zhulenev 
20425f80e16SEugene Zhulenev   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
20525f80e16SEugene Zhulenev 
20625f80e16SEugene Zhulenev   // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
20725f80e16SEugene Zhulenev   // blocks, adding async.coro operations and setting up control flow.
20825f80e16SEugene Zhulenev   CoroMachinery coro = setupCoroMachinery(func);
20925f80e16SEugene Zhulenev 
21025f80e16SEugene Zhulenev   // Suspend async function at the end of an entry block, and resume it using
21125f80e16SEugene Zhulenev   // Async resume operation (execution will be resumed in a thread managed by
21225f80e16SEugene Zhulenev   // the async runtime).
21325f80e16SEugene Zhulenev   Block *entryBlock = &func.getBlocks().front();
21425f80e16SEugene Zhulenev   auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);
21525f80e16SEugene Zhulenev 
21625f80e16SEugene Zhulenev   // Save the coroutine state: async.coro.save
21725f80e16SEugene Zhulenev   auto coroSaveOp =
21825f80e16SEugene Zhulenev       builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
21925f80e16SEugene Zhulenev 
22025f80e16SEugene Zhulenev   // Pass coroutine to the runtime to be resumed on a runtime managed thread.
22125f80e16SEugene Zhulenev   builder.create<RuntimeResumeOp>(coro.coroHandle);
22225f80e16SEugene Zhulenev 
22325f80e16SEugene Zhulenev   // Split the entry block before the terminator (branch to suspend block).
22425f80e16SEugene Zhulenev   auto *terminatorOp = entryBlock->getTerminator();
22525f80e16SEugene Zhulenev   Block *suspended = terminatorOp->getBlock();
22625f80e16SEugene Zhulenev   Block *resume = suspended->splitBlock(terminatorOp);
22725f80e16SEugene Zhulenev 
22825f80e16SEugene Zhulenev   // Add async.coro.suspend as a suspended block terminator.
22925f80e16SEugene Zhulenev   builder.setInsertionPointToEnd(suspended);
23025f80e16SEugene Zhulenev   builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
23125f80e16SEugene Zhulenev                                 coro.cleanup);
23225f80e16SEugene Zhulenev 
23325f80e16SEugene Zhulenev   size_t numDependencies = execute.dependencies().size();
23425f80e16SEugene Zhulenev   size_t numOperands = execute.operands().size();
23525f80e16SEugene Zhulenev 
23625f80e16SEugene Zhulenev   // Await on all dependencies before starting to execute the body region.
23725f80e16SEugene Zhulenev   builder.setInsertionPointToStart(resume);
23825f80e16SEugene Zhulenev   for (size_t i = 0; i < numDependencies; ++i)
23925f80e16SEugene Zhulenev     builder.create<AwaitOp>(func.getArgument(i));
24025f80e16SEugene Zhulenev 
24125f80e16SEugene Zhulenev   // Await on all async value operands and unwrap the payload.
24225f80e16SEugene Zhulenev   SmallVector<Value, 4> unwrappedOperands(numOperands);
24325f80e16SEugene Zhulenev   for (size_t i = 0; i < numOperands; ++i) {
24425f80e16SEugene Zhulenev     Value operand = func.getArgument(numDependencies + i);
24525f80e16SEugene Zhulenev     unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
24625f80e16SEugene Zhulenev   }
24725f80e16SEugene Zhulenev 
24825f80e16SEugene Zhulenev   // Map from function inputs defined above the execute op to the function
24925f80e16SEugene Zhulenev   // arguments.
25025f80e16SEugene Zhulenev   BlockAndValueMapping valueMapping;
25125f80e16SEugene Zhulenev   valueMapping.map(functionInputs, func.getArguments());
25225f80e16SEugene Zhulenev   valueMapping.map(execute.body().getArguments(), unwrappedOperands);
25325f80e16SEugene Zhulenev 
25425f80e16SEugene Zhulenev   // Clone all operations from the execute operation body into the outlined
25525f80e16SEugene Zhulenev   // function body.
25625f80e16SEugene Zhulenev   for (Operation &op : execute.body().getOps())
25725f80e16SEugene Zhulenev     builder.clone(op, valueMapping);
25825f80e16SEugene Zhulenev 
25925f80e16SEugene Zhulenev   // Replace the original `async.execute` with a call to outlined function.
26025f80e16SEugene Zhulenev   ImplicitLocOpBuilder callBuilder(loc, execute);
26125f80e16SEugene Zhulenev   auto callOutlinedFunc = callBuilder.create<CallOp>(
26225f80e16SEugene Zhulenev       func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
26325f80e16SEugene Zhulenev   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
26425f80e16SEugene Zhulenev   execute.erase();
26525f80e16SEugene Zhulenev 
26625f80e16SEugene Zhulenev   return {func, coro};
26725f80e16SEugene Zhulenev }
26825f80e16SEugene Zhulenev 
26925f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
27025f80e16SEugene Zhulenev // Convert async.create_group operation to async.runtime.create
27125f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
27225f80e16SEugene Zhulenev 
27325f80e16SEugene Zhulenev namespace {
27425f80e16SEugene Zhulenev class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
27525f80e16SEugene Zhulenev public:
27625f80e16SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
27725f80e16SEugene Zhulenev 
27825f80e16SEugene Zhulenev   LogicalResult
27925f80e16SEugene Zhulenev   matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
28025f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
28125f80e16SEugene Zhulenev     rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
28225f80e16SEugene Zhulenev         op, GroupType::get(op->getContext()));
28325f80e16SEugene Zhulenev     return success();
28425f80e16SEugene Zhulenev   }
28525f80e16SEugene Zhulenev };
28625f80e16SEugene Zhulenev } // namespace
28725f80e16SEugene Zhulenev 
28825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
28925f80e16SEugene Zhulenev // Convert async.add_to_group operation to async.runtime.add_to_group.
29025f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
29125f80e16SEugene Zhulenev 
29225f80e16SEugene Zhulenev namespace {
29325f80e16SEugene Zhulenev class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
29425f80e16SEugene Zhulenev public:
29525f80e16SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
29625f80e16SEugene Zhulenev 
29725f80e16SEugene Zhulenev   LogicalResult
29825f80e16SEugene Zhulenev   matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
29925f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
30025f80e16SEugene Zhulenev     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
30125f80e16SEugene Zhulenev         op, rewriter.getIndexType(), operands);
30225f80e16SEugene Zhulenev     return success();
30325f80e16SEugene Zhulenev   }
30425f80e16SEugene Zhulenev };
30525f80e16SEugene Zhulenev } // namespace
30625f80e16SEugene Zhulenev 
30725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
30825f80e16SEugene Zhulenev // Convert async.await and async.await_all operations to the async.runtime.await
30925f80e16SEugene Zhulenev // or async.runtime.await_and_resume operations.
31025f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
31125f80e16SEugene Zhulenev 
31225f80e16SEugene Zhulenev namespace {
31325f80e16SEugene Zhulenev template <typename AwaitType, typename AwaitableType>
31425f80e16SEugene Zhulenev class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
31525f80e16SEugene Zhulenev   using AwaitAdaptor = typename AwaitType::Adaptor;
31625f80e16SEugene Zhulenev 
31725f80e16SEugene Zhulenev public:
31825f80e16SEugene Zhulenev   AwaitOpLoweringBase(
31925f80e16SEugene Zhulenev       MLIRContext *ctx,
32025f80e16SEugene Zhulenev       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
32125f80e16SEugene Zhulenev       : OpConversionPattern<AwaitType>(ctx),
32225f80e16SEugene Zhulenev         outlinedFunctions(outlinedFunctions) {}
32325f80e16SEugene Zhulenev 
32425f80e16SEugene Zhulenev   LogicalResult
32525f80e16SEugene Zhulenev   matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
32625f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
32725f80e16SEugene Zhulenev     // We can only await on one the `AwaitableType` (for `await` it can be
32825f80e16SEugene Zhulenev     // a `token` or a `value`, for `await_all` it must be a `group`).
32925f80e16SEugene Zhulenev     if (!op.operand().getType().template isa<AwaitableType>())
33025f80e16SEugene Zhulenev       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
33125f80e16SEugene Zhulenev 
33225f80e16SEugene Zhulenev     // Check if await operation is inside the outlined coroutine function.
33325f80e16SEugene Zhulenev     auto func = op->template getParentOfType<FuncOp>();
33425f80e16SEugene Zhulenev     auto outlined = outlinedFunctions.find(func);
33525f80e16SEugene Zhulenev     const bool isInCoroutine = outlined != outlinedFunctions.end();
33625f80e16SEugene Zhulenev 
33725f80e16SEugene Zhulenev     Location loc = op->getLoc();
33825f80e16SEugene Zhulenev     Value operand = AwaitAdaptor(operands).operand();
33925f80e16SEugene Zhulenev 
34025f80e16SEugene Zhulenev     // Inside regular functions we use the blocking wait operation to wait for
34125f80e16SEugene Zhulenev     // the async object (token, value or group) to become available.
34225f80e16SEugene Zhulenev     if (!isInCoroutine)
34325f80e16SEugene Zhulenev       rewriter.create<RuntimeAwaitOp>(loc, operand);
34425f80e16SEugene Zhulenev 
34525f80e16SEugene Zhulenev     // Inside the coroutine we convert await operation into coroutine suspension
34625f80e16SEugene Zhulenev     // point, and resume execution asynchronously.
34725f80e16SEugene Zhulenev     if (isInCoroutine) {
34825f80e16SEugene Zhulenev       const CoroMachinery &coro = outlined->getSecond();
34925f80e16SEugene Zhulenev       Block *suspended = op->getBlock();
35025f80e16SEugene Zhulenev 
35125f80e16SEugene Zhulenev       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
35225f80e16SEugene Zhulenev       MLIRContext *ctx = op->getContext();
35325f80e16SEugene Zhulenev 
35425f80e16SEugene Zhulenev       // Save the coroutine state and resume on a runtime managed thread when
35525f80e16SEugene Zhulenev       // the operand becomes available.
35625f80e16SEugene Zhulenev       auto coroSaveOp =
35725f80e16SEugene Zhulenev           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
35825f80e16SEugene Zhulenev       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
35925f80e16SEugene Zhulenev 
36025f80e16SEugene Zhulenev       // Split the entry block before the await operation.
36125f80e16SEugene Zhulenev       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
36225f80e16SEugene Zhulenev 
36325f80e16SEugene Zhulenev       // Add async.coro.suspend as a suspended block terminator.
36425f80e16SEugene Zhulenev       builder.setInsertionPointToEnd(suspended);
36525f80e16SEugene Zhulenev       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
36625f80e16SEugene Zhulenev                                     coro.cleanup);
36725f80e16SEugene Zhulenev 
36825f80e16SEugene Zhulenev       // Make sure that replacement value will be constructed in resume block.
36925f80e16SEugene Zhulenev       rewriter.setInsertionPointToStart(resume);
37025f80e16SEugene Zhulenev     }
37125f80e16SEugene Zhulenev 
37225f80e16SEugene Zhulenev     // Erase or replace the await operation with the new value.
37325f80e16SEugene Zhulenev     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
37425f80e16SEugene Zhulenev       rewriter.replaceOp(op, replaceWith);
37525f80e16SEugene Zhulenev     else
37625f80e16SEugene Zhulenev       rewriter.eraseOp(op);
37725f80e16SEugene Zhulenev 
37825f80e16SEugene Zhulenev     return success();
37925f80e16SEugene Zhulenev   }
38025f80e16SEugene Zhulenev 
38125f80e16SEugene Zhulenev   virtual Value getReplacementValue(AwaitType op, Value operand,
38225f80e16SEugene Zhulenev                                     ConversionPatternRewriter &rewriter) const {
38325f80e16SEugene Zhulenev     return Value();
38425f80e16SEugene Zhulenev   }
38525f80e16SEugene Zhulenev 
38625f80e16SEugene Zhulenev private:
38725f80e16SEugene Zhulenev   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
38825f80e16SEugene Zhulenev };
38925f80e16SEugene Zhulenev 
39025f80e16SEugene Zhulenev /// Lowering for `async.await` with a token operand.
39125f80e16SEugene Zhulenev class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
39225f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
39325f80e16SEugene Zhulenev 
39425f80e16SEugene Zhulenev public:
39525f80e16SEugene Zhulenev   using Base::Base;
39625f80e16SEugene Zhulenev };
39725f80e16SEugene Zhulenev 
39825f80e16SEugene Zhulenev /// Lowering for `async.await` with a value operand.
39925f80e16SEugene Zhulenev class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
40025f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
40125f80e16SEugene Zhulenev 
40225f80e16SEugene Zhulenev public:
40325f80e16SEugene Zhulenev   using Base::Base;
40425f80e16SEugene Zhulenev 
40525f80e16SEugene Zhulenev   Value
40625f80e16SEugene Zhulenev   getReplacementValue(AwaitOp op, Value operand,
40725f80e16SEugene Zhulenev                       ConversionPatternRewriter &rewriter) const override {
40825f80e16SEugene Zhulenev     // Load from the async value storage.
40925f80e16SEugene Zhulenev     auto valueType = operand.getType().cast<ValueType>().getValueType();
41025f80e16SEugene Zhulenev     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
41125f80e16SEugene Zhulenev   }
41225f80e16SEugene Zhulenev };
41325f80e16SEugene Zhulenev 
41425f80e16SEugene Zhulenev /// Lowering for `async.await_all` operation.
41525f80e16SEugene Zhulenev class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
41625f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
41725f80e16SEugene Zhulenev 
41825f80e16SEugene Zhulenev public:
41925f80e16SEugene Zhulenev   using Base::Base;
42025f80e16SEugene Zhulenev };
42125f80e16SEugene Zhulenev 
42225f80e16SEugene Zhulenev } // namespace
42325f80e16SEugene Zhulenev 
42425f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
42525f80e16SEugene Zhulenev // Convert async.yield operation to async.runtime operations.
42625f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
42725f80e16SEugene Zhulenev 
42825f80e16SEugene Zhulenev class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
42925f80e16SEugene Zhulenev public:
43025f80e16SEugene Zhulenev   YieldOpLowering(
43125f80e16SEugene Zhulenev       MLIRContext *ctx,
43225f80e16SEugene Zhulenev       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
43325f80e16SEugene Zhulenev       : OpConversionPattern<async::YieldOp>(ctx),
43425f80e16SEugene Zhulenev         outlinedFunctions(outlinedFunctions) {}
43525f80e16SEugene Zhulenev 
43625f80e16SEugene Zhulenev   LogicalResult
43725f80e16SEugene Zhulenev   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
43825f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
43925f80e16SEugene Zhulenev     // Check if yield operation is inside the outlined coroutine function.
44025f80e16SEugene Zhulenev     auto func = op->template getParentOfType<FuncOp>();
44125f80e16SEugene Zhulenev     auto outlined = outlinedFunctions.find(func);
44225f80e16SEugene Zhulenev     if (outlined == outlinedFunctions.end())
44325f80e16SEugene Zhulenev       return rewriter.notifyMatchFailure(
44425f80e16SEugene Zhulenev           op, "operation is not inside the outlined async.execute function");
44525f80e16SEugene Zhulenev 
44625f80e16SEugene Zhulenev     Location loc = op->getLoc();
44725f80e16SEugene Zhulenev     const CoroMachinery &coro = outlined->getSecond();
44825f80e16SEugene Zhulenev 
44925f80e16SEugene Zhulenev     // Store yielded values into the async values storage and switch async
45025f80e16SEugene Zhulenev     // values state to available.
45125f80e16SEugene Zhulenev     for (auto tuple : llvm::zip(operands, coro.returnValues)) {
45225f80e16SEugene Zhulenev       Value yieldValue = std::get<0>(tuple);
45325f80e16SEugene Zhulenev       Value asyncValue = std::get<1>(tuple);
45425f80e16SEugene Zhulenev       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
45525f80e16SEugene Zhulenev       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
45625f80e16SEugene Zhulenev     }
45725f80e16SEugene Zhulenev 
45825f80e16SEugene Zhulenev     // Switch the coroutine completion token to available state.
45925f80e16SEugene Zhulenev     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
46025f80e16SEugene Zhulenev 
46125f80e16SEugene Zhulenev     return success();
46225f80e16SEugene Zhulenev   }
46325f80e16SEugene Zhulenev 
46425f80e16SEugene Zhulenev private:
46525f80e16SEugene Zhulenev   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
46625f80e16SEugene Zhulenev };
46725f80e16SEugene Zhulenev 
46825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
46925f80e16SEugene Zhulenev 
47025f80e16SEugene Zhulenev void AsyncToAsyncRuntimePass::runOnOperation() {
47125f80e16SEugene Zhulenev   ModuleOp module = getOperation();
47225f80e16SEugene Zhulenev   SymbolTable symbolTable(module);
47325f80e16SEugene Zhulenev 
47425f80e16SEugene Zhulenev   // Outline all `async.execute` body regions into async functions (coroutines).
47525f80e16SEugene Zhulenev   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
47625f80e16SEugene Zhulenev 
47725f80e16SEugene Zhulenev   module.walk([&](ExecuteOp execute) {
47825f80e16SEugene Zhulenev     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
47925f80e16SEugene Zhulenev   });
48025f80e16SEugene Zhulenev 
48125f80e16SEugene Zhulenev   LLVM_DEBUG({
48225f80e16SEugene Zhulenev     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
48325f80e16SEugene Zhulenev                  << " functions built from async.execute operations\n";
48425f80e16SEugene Zhulenev   });
48525f80e16SEugene Zhulenev 
48625f80e16SEugene Zhulenev   // Lower async operations to async.runtime operations.
48725f80e16SEugene Zhulenev   MLIRContext *ctx = module->getContext();
488*dc4e913bSChris Lattner   RewritePatternSet asyncPatterns(ctx);
48925f80e16SEugene Zhulenev 
49025f80e16SEugene Zhulenev   // Async lowering does not use type converter because it must preserve all
49125f80e16SEugene Zhulenev   // types for async.runtime operations.
492*dc4e913bSChris Lattner   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
493*dc4e913bSChris Lattner   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
49425f80e16SEugene Zhulenev                     AwaitAllOpLowering, YieldOpLowering>(ctx,
49525f80e16SEugene Zhulenev                                                          outlinedFunctions);
49625f80e16SEugene Zhulenev 
49725f80e16SEugene Zhulenev   // All high level async operations must be lowered to the runtime operations.
49825f80e16SEugene Zhulenev   ConversionTarget runtimeTarget(*ctx);
49925f80e16SEugene Zhulenev   runtimeTarget.addLegalDialect<AsyncDialect>();
50025f80e16SEugene Zhulenev   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
50125f80e16SEugene Zhulenev   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
50225f80e16SEugene Zhulenev 
50325f80e16SEugene Zhulenev   if (failed(applyPartialConversion(module, runtimeTarget,
50425f80e16SEugene Zhulenev                                     std::move(asyncPatterns)))) {
50525f80e16SEugene Zhulenev     signalPassFailure();
50625f80e16SEugene Zhulenev     return;
50725f80e16SEugene Zhulenev   }
50825f80e16SEugene Zhulenev }
50925f80e16SEugene Zhulenev 
51025f80e16SEugene Zhulenev std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
51125f80e16SEugene Zhulenev   return std::make_unique<AsyncToAsyncRuntimePass>();
51225f80e16SEugene Zhulenev }
513