125f80e16SEugene Zhulenev //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
225f80e16SEugene Zhulenev //
325f80e16SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
425f80e16SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
525f80e16SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
625f80e16SEugene Zhulenev //
725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
825f80e16SEugene Zhulenev //
925f80e16SEugene Zhulenev // This file implements lowering from high level async operations to async.coro
1025f80e16SEugene Zhulenev // and async.runtime operations.
1125f80e16SEugene Zhulenev //
1225f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
1325f80e16SEugene Zhulenev 
1425f80e16SEugene Zhulenev #include "PassDetail.h"
15ace01605SRiver Riddle #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
16a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1725f80e16SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
1825f80e16SEugene Zhulenev #include "mlir/Dialect/Async/Passes.h"
19ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2023aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
21*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
2225f80e16SEugene Zhulenev #include "mlir/IR/BlockAndValueMapping.h"
2325f80e16SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h"
2425f80e16SEugene Zhulenev #include "mlir/IR/PatternMatch.h"
2525f80e16SEugene Zhulenev #include "mlir/Transforms/DialectConversion.h"
2625f80e16SEugene Zhulenev #include "mlir/Transforms/RegionUtils.h"
2725f80e16SEugene Zhulenev #include "llvm/ADT/SetVector.h"
28297a5b7cSNico Weber #include "llvm/Support/Debug.h"
2925f80e16SEugene Zhulenev 
3025f80e16SEugene Zhulenev using namespace mlir;
3125f80e16SEugene Zhulenev using namespace mlir::async;
3225f80e16SEugene Zhulenev 
#define DEBUG_TYPE "async-to-async-runtime"

/// Name given to every function outlined from an `async.execute` op region
/// (the symbol table uniquifies it on insertion).
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
3625f80e16SEugene Zhulenev 
3725f80e16SEugene Zhulenev namespace {
3825f80e16SEugene Zhulenev 
3925f80e16SEugene Zhulenev class AsyncToAsyncRuntimePass
4025f80e16SEugene Zhulenev     : public AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
4125f80e16SEugene Zhulenev public:
4225f80e16SEugene Zhulenev   AsyncToAsyncRuntimePass() = default;
4325f80e16SEugene Zhulenev   void runOnOperation() override;
4425f80e16SEugene Zhulenev };
4525f80e16SEugene Zhulenev 
4625f80e16SEugene Zhulenev } // namespace
4725f80e16SEugene Zhulenev 
4825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
4925f80e16SEugene Zhulenev // async.execute op outlining to the coroutine functions.
5025f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
5125f80e16SEugene Zhulenev 
5225f80e16SEugene Zhulenev /// Function targeted for coroutine transformation has two additional blocks at
5325f80e16SEugene Zhulenev /// the end: coroutine cleanup and coroutine suspension.
5425f80e16SEugene Zhulenev ///
5525f80e16SEugene Zhulenev /// async.await op lowering additionaly creates a resume block for each
5625f80e16SEugene Zhulenev /// operation to enable non-blocking waiting via coroutine suspension.
5725f80e16SEugene Zhulenev namespace {
5825f80e16SEugene Zhulenev struct CoroMachinery {
5958ceae95SRiver Riddle   func::FuncOp func;
6039957aa4SEugene Zhulenev 
6125f80e16SEugene Zhulenev   // Async execute region returns a completion token, and an async value for
6225f80e16SEugene Zhulenev   // each yielded value.
6325f80e16SEugene Zhulenev   //
6425f80e16SEugene Zhulenev   //   %token, %result = async.execute -> !async.value<T> {
65cb3aa49eSMogball   //     %0 = arith.constant ... : T
6625f80e16SEugene Zhulenev   //     async.yield %0 : T
6725f80e16SEugene Zhulenev   //   }
6825f80e16SEugene Zhulenev   Value asyncToken; // token representing completion of the async region
6925f80e16SEugene Zhulenev   llvm::SmallVector<Value, 4> returnValues; // returned async values
7025f80e16SEugene Zhulenev 
7125f80e16SEugene Zhulenev   Value coroHandle; // coroutine handle (!async.coro.handle value)
721c144410Sbakhtiyar   Block *entry;     // coroutine entry block
7339957aa4SEugene Zhulenev   Block *setError;  // switch completion token and all values to error state
7425f80e16SEugene Zhulenev   Block *cleanup;   // coroutine cleanup block
7525f80e16SEugene Zhulenev   Block *suspend;   // coroutine suspension block
7625f80e16SEugene Zhulenev };
7725f80e16SEugene Zhulenev } // namespace
7825f80e16SEugene Zhulenev 
796ea22d46Sbakhtiyar /// Utility to partially update the regular function CFG to the coroutine CFG
806ea22d46Sbakhtiyar /// compatible with LLVM coroutines switched-resume lowering using
811c144410Sbakhtiyar /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
821c144410Sbakhtiyar /// that branches into preexisting entry block. Also inserts trailing blocks.
836ea22d46Sbakhtiyar ///
846ea22d46Sbakhtiyar /// The result types of the passed `func` must start with an `async.token`
856ea22d46Sbakhtiyar /// and be continued with some number of `async.value`s.
866ea22d46Sbakhtiyar ///
871c144410Sbakhtiyar /// The func given to this function needs to have been preprocessed to have
881c144410Sbakhtiyar /// either branch or yield ops as terminators. Branches to the cleanup block are
891c144410Sbakhtiyar /// inserted after each yield.
9025f80e16SEugene Zhulenev ///
9125f80e16SEugene Zhulenev /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
9225f80e16SEugene Zhulenev ///
9325f80e16SEugene Zhulenev ///  - `entry` block sets up the coroutine.
9439957aa4SEugene Zhulenev ///  - `set_error` block sets completion token and async values state to error.
9525f80e16SEugene Zhulenev ///  - `cleanup` block cleans up the coroutine state.
9625f80e16SEugene Zhulenev ///  - `suspend block after the @llvm.coro.end() defines what value will be
9725f80e16SEugene Zhulenev ///    returned to the initial caller of a coroutine. Everything before the
9825f80e16SEugene Zhulenev ///    @llvm.coro.end() will be executed at every suspension point.
9925f80e16SEugene Zhulenev ///
10025f80e16SEugene Zhulenev /// Coroutine structure (only the important bits):
10125f80e16SEugene Zhulenev ///
1026ea22d46Sbakhtiyar ///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
10325f80e16SEugene Zhulenev ///   {
10425f80e16SEugene Zhulenev ///     ^entry(<function-arguments>):
10525f80e16SEugene Zhulenev ///       %token = <async token> : !async.token    // create async runtime token
10625f80e16SEugene Zhulenev ///       %value = <async value> : !async.value<T> // create async value
10725f80e16SEugene Zhulenev ///       %id = async.coro.id                      // create a coroutine id
10825f80e16SEugene Zhulenev ///       %hdl = async.coro.begin %id              // create a coroutine handle
109ace01605SRiver Riddle ///       cf.br ^preexisting_entry_block
1106ea22d46Sbakhtiyar ///
1111c144410Sbakhtiyar ///     /*  preexisting blocks modified to branch to the cleanup block */
11225f80e16SEugene Zhulenev ///
11339957aa4SEugene Zhulenev ///     ^set_error: // this block created lazily only if needed (see code below)
11439957aa4SEugene Zhulenev ///       async.runtime.set_error %token : !async.token
11539957aa4SEugene Zhulenev ///       async.runtime.set_error %value : !async.value<T>
116ace01605SRiver Riddle ///       cf.br ^cleanup
11739957aa4SEugene Zhulenev ///
11825f80e16SEugene Zhulenev ///     ^cleanup:
11925f80e16SEugene Zhulenev ///       async.coro.free %hdl // delete the coroutine state
120ace01605SRiver Riddle ///       cf.br ^suspend
12125f80e16SEugene Zhulenev ///
12225f80e16SEugene Zhulenev ///     ^suspend:
12325f80e16SEugene Zhulenev ///       async.coro.end %hdl // marks the end of a coroutine
12425f80e16SEugene Zhulenev ///       return %token, %value : !async.token, !async.value<T>
12525f80e16SEugene Zhulenev ///   }
12625f80e16SEugene Zhulenev ///
setupCoroMachinery(func::FuncOp func)12758ceae95SRiver Riddle static CoroMachinery setupCoroMachinery(func::FuncOp func) {
1286ea22d46Sbakhtiyar   assert(!func.getBlocks().empty() && "Function must have an entry block");
12925f80e16SEugene Zhulenev 
13025f80e16SEugene Zhulenev   MLIRContext *ctx = func.getContext();
1316ea22d46Sbakhtiyar   Block *entryBlock = &func.getBlocks().front();
1321c144410Sbakhtiyar   Block *originalEntryBlock =
1331c144410Sbakhtiyar       entryBlock->splitBlock(entryBlock->getOperations().begin());
13425f80e16SEugene Zhulenev   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
13525f80e16SEugene Zhulenev 
13625f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
13725f80e16SEugene Zhulenev   // Allocate async token/values that we will return from a ramp function.
13825f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
13925f80e16SEugene Zhulenev   auto retToken = builder.create<RuntimeCreateOp>(TokenType::get(ctx)).result();
14025f80e16SEugene Zhulenev 
14125f80e16SEugene Zhulenev   llvm::SmallVector<Value, 4> retValues;
14225f80e16SEugene Zhulenev   for (auto resType : func.getCallableResults().drop_front())
14325f80e16SEugene Zhulenev     retValues.emplace_back(builder.create<RuntimeCreateOp>(resType).result());
14425f80e16SEugene Zhulenev 
14525f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
14625f80e16SEugene Zhulenev   // Initialize coroutine: get coroutine id and coroutine handle.
14725f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
14825f80e16SEugene Zhulenev   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
14925f80e16SEugene Zhulenev   auto coroHdlOp =
15025f80e16SEugene Zhulenev       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.id());
151ace01605SRiver Riddle   builder.create<cf::BranchOp>(originalEntryBlock);
15225f80e16SEugene Zhulenev 
15325f80e16SEugene Zhulenev   Block *cleanupBlock = func.addBlock();
15425f80e16SEugene Zhulenev   Block *suspendBlock = func.addBlock();
15525f80e16SEugene Zhulenev 
15625f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
15725f80e16SEugene Zhulenev   // Coroutine cleanup block: deallocate coroutine frame, free the memory.
15825f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
15925f80e16SEugene Zhulenev   builder.setInsertionPointToStart(cleanupBlock);
16025f80e16SEugene Zhulenev   builder.create<CoroFreeOp>(coroIdOp.id(), coroHdlOp.handle());
16125f80e16SEugene Zhulenev 
16225f80e16SEugene Zhulenev   // Branch into the suspend block.
163ace01605SRiver Riddle   builder.create<cf::BranchOp>(suspendBlock);
16425f80e16SEugene Zhulenev 
16525f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
16625f80e16SEugene Zhulenev   // Coroutine suspend block: mark the end of a coroutine and return allocated
16725f80e16SEugene Zhulenev   // async token.
16825f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
16925f80e16SEugene Zhulenev   builder.setInsertionPointToStart(suspendBlock);
17025f80e16SEugene Zhulenev 
17125f80e16SEugene Zhulenev   // Mark the end of a coroutine: async.coro.end
17225f80e16SEugene Zhulenev   builder.create<CoroEndOp>(coroHdlOp.handle());
17325f80e16SEugene Zhulenev 
17425f80e16SEugene Zhulenev   // Return created `async.token` and `async.values` from the suspend block.
17525f80e16SEugene Zhulenev   // This will be the return value of a coroutine ramp function.
17625f80e16SEugene Zhulenev   SmallVector<Value, 4> ret{retToken};
17725f80e16SEugene Zhulenev   ret.insert(ret.end(), retValues.begin(), retValues.end());
17823aa5a74SRiver Riddle   builder.create<func::ReturnOp>(ret);
17925f80e16SEugene Zhulenev 
18025f80e16SEugene Zhulenev   // `async.await` op lowering will create resume blocks for async
18125f80e16SEugene Zhulenev   // continuations, and will conditionally branch to cleanup or suspend blocks.
18225f80e16SEugene Zhulenev 
183f8d5c73cSRiver Riddle   for (Block &block : func.getBody().getBlocks()) {
1841c144410Sbakhtiyar     if (&block == entryBlock || &block == cleanupBlock ||
1851c144410Sbakhtiyar         &block == suspendBlock)
1861c144410Sbakhtiyar       continue;
1871c144410Sbakhtiyar     Operation *terminator = block.getTerminator();
1881c144410Sbakhtiyar     if (auto yield = dyn_cast<YieldOp>(terminator)) {
1891c144410Sbakhtiyar       builder.setInsertionPointToEnd(&block);
190ace01605SRiver Riddle       builder.create<cf::BranchOp>(cleanupBlock);
1911c144410Sbakhtiyar     }
1921c144410Sbakhtiyar   }
1931c144410Sbakhtiyar 
194c75cedc2SChuanqi Xu   // The switch-resumed API based coroutine should be marked with
195735e6c40SChuanqi Xu   // coroutine.presplit attribute to mark the function as a coroutine.
196735e6c40SChuanqi Xu   func->setAttr("passthrough", builder.getArrayAttr(
197735e6c40SChuanqi Xu                                    StringAttr::get(ctx, "presplitcoroutine")));
198c75cedc2SChuanqi Xu 
19925f80e16SEugene Zhulenev   CoroMachinery machinery;
20039957aa4SEugene Zhulenev   machinery.func = func;
20125f80e16SEugene Zhulenev   machinery.asyncToken = retToken;
20225f80e16SEugene Zhulenev   machinery.returnValues = retValues;
20325f80e16SEugene Zhulenev   machinery.coroHandle = coroHdlOp.handle();
2041c144410Sbakhtiyar   machinery.entry = entryBlock;
20539957aa4SEugene Zhulenev   machinery.setError = nullptr; // created lazily only if needed
20625f80e16SEugene Zhulenev   machinery.cleanup = cleanupBlock;
20725f80e16SEugene Zhulenev   machinery.suspend = suspendBlock;
20825f80e16SEugene Zhulenev   return machinery;
20925f80e16SEugene Zhulenev }
21025f80e16SEugene Zhulenev 
21139957aa4SEugene Zhulenev // Lazily creates `set_error` block only if it is required for lowering to the
21239957aa4SEugene Zhulenev // runtime operations (see for example lowering of assert operation).
setupSetErrorBlock(CoroMachinery & coro)21339957aa4SEugene Zhulenev static Block *setupSetErrorBlock(CoroMachinery &coro) {
21439957aa4SEugene Zhulenev   if (coro.setError)
21539957aa4SEugene Zhulenev     return coro.setError;
21639957aa4SEugene Zhulenev 
21739957aa4SEugene Zhulenev   coro.setError = coro.func.addBlock();
21839957aa4SEugene Zhulenev   coro.setError->moveBefore(coro.cleanup);
21939957aa4SEugene Zhulenev 
22039957aa4SEugene Zhulenev   auto builder =
22139957aa4SEugene Zhulenev       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
22239957aa4SEugene Zhulenev 
22339957aa4SEugene Zhulenev   // Coroutine set_error block: set error on token and all returned values.
22439957aa4SEugene Zhulenev   builder.create<RuntimeSetErrorOp>(coro.asyncToken);
22539957aa4SEugene Zhulenev   for (Value retValue : coro.returnValues)
22639957aa4SEugene Zhulenev     builder.create<RuntimeSetErrorOp>(retValue);
22739957aa4SEugene Zhulenev 
22839957aa4SEugene Zhulenev   // Branch into the cleanup block.
229ace01605SRiver Riddle   builder.create<cf::BranchOp>(coro.cleanup);
23039957aa4SEugene Zhulenev 
23139957aa4SEugene Zhulenev   return coro.setError;
23239957aa4SEugene Zhulenev }
23339957aa4SEugene Zhulenev 
23425f80e16SEugene Zhulenev /// Outline the body region attached to the `async.execute` op into a standalone
23525f80e16SEugene Zhulenev /// function.
23625f80e16SEugene Zhulenev ///
23725f80e16SEugene Zhulenev /// Note that this is not reversible transformation.
23858ceae95SRiver Riddle static std::pair<func::FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable & symbolTable,ExecuteOp execute)23925f80e16SEugene Zhulenev outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
24025f80e16SEugene Zhulenev   ModuleOp module = execute->getParentOfType<ModuleOp>();
24125f80e16SEugene Zhulenev 
24225f80e16SEugene Zhulenev   MLIRContext *ctx = module.getContext();
24325f80e16SEugene Zhulenev   Location loc = execute.getLoc();
24425f80e16SEugene Zhulenev 
245b537c5b4SEugene Zhulenev   // Make sure that all constants will be inside the outlined async function to
246b537c5b4SEugene Zhulenev   // reduce the number of function arguments.
247b537c5b4SEugene Zhulenev   cloneConstantsIntoTheRegion(execute.body());
248b537c5b4SEugene Zhulenev 
24925f80e16SEugene Zhulenev   // Collect all outlined function inputs.
2504efb7754SRiver Riddle   SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
25125f80e16SEugene Zhulenev                                         execute.dependencies().end());
25225f80e16SEugene Zhulenev   functionInputs.insert(execute.operands().begin(), execute.operands().end());
25325f80e16SEugene Zhulenev   getUsedValuesDefinedAbove(execute.body(), functionInputs);
25425f80e16SEugene Zhulenev 
25525f80e16SEugene Zhulenev   // Collect types for the outlined function inputs and outputs.
25625f80e16SEugene Zhulenev   auto typesRange = llvm::map_range(
25725f80e16SEugene Zhulenev       functionInputs, [](Value value) { return value.getType(); });
25825f80e16SEugene Zhulenev   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
25925f80e16SEugene Zhulenev   auto outputTypes = execute.getResultTypes();
26025f80e16SEugene Zhulenev 
26125f80e16SEugene Zhulenev   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
26225f80e16SEugene Zhulenev   auto funcAttrs = ArrayRef<NamedAttribute>();
26325f80e16SEugene Zhulenev 
26425f80e16SEugene Zhulenev   // TODO: Derive outlined function name from the parent FuncOp (support
26525f80e16SEugene Zhulenev   // multiple nested async.execute operations).
26658ceae95SRiver Riddle   func::FuncOp func =
26758ceae95SRiver Riddle       func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
268973ddb7dSMehdi Amini   symbolTable.insert(func);
26925f80e16SEugene Zhulenev 
27025f80e16SEugene Zhulenev   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
2711c144410Sbakhtiyar   auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
27225f80e16SEugene Zhulenev 
2731c144410Sbakhtiyar   // Prepare for coroutine conversion by creating the body of the function.
2741c144410Sbakhtiyar   {
27525f80e16SEugene Zhulenev     size_t numDependencies = execute.dependencies().size();
27625f80e16SEugene Zhulenev     size_t numOperands = execute.operands().size();
27725f80e16SEugene Zhulenev 
27825f80e16SEugene Zhulenev     // Await on all dependencies before starting to execute the body region.
27925f80e16SEugene Zhulenev     for (size_t i = 0; i < numDependencies; ++i)
28025f80e16SEugene Zhulenev       builder.create<AwaitOp>(func.getArgument(i));
28125f80e16SEugene Zhulenev 
28225f80e16SEugene Zhulenev     // Await on all async value operands and unwrap the payload.
28325f80e16SEugene Zhulenev     SmallVector<Value, 4> unwrappedOperands(numOperands);
28425f80e16SEugene Zhulenev     for (size_t i = 0; i < numOperands; ++i) {
28525f80e16SEugene Zhulenev       Value operand = func.getArgument(numDependencies + i);
28625f80e16SEugene Zhulenev       unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).result();
28725f80e16SEugene Zhulenev     }
28825f80e16SEugene Zhulenev 
28925f80e16SEugene Zhulenev     // Map from function inputs defined above the execute op to the function
29025f80e16SEugene Zhulenev     // arguments.
29125f80e16SEugene Zhulenev     BlockAndValueMapping valueMapping;
29225f80e16SEugene Zhulenev     valueMapping.map(functionInputs, func.getArguments());
29325f80e16SEugene Zhulenev     valueMapping.map(execute.body().getArguments(), unwrappedOperands);
29425f80e16SEugene Zhulenev 
29525f80e16SEugene Zhulenev     // Clone all operations from the execute operation body into the outlined
29625f80e16SEugene Zhulenev     // function body.
29725f80e16SEugene Zhulenev     for (Operation &op : execute.body().getOps())
29825f80e16SEugene Zhulenev       builder.clone(op, valueMapping);
2991c144410Sbakhtiyar   }
3001c144410Sbakhtiyar 
3011c144410Sbakhtiyar   // Adding entry/cleanup/suspend blocks.
3021c144410Sbakhtiyar   CoroMachinery coro = setupCoroMachinery(func);
3031c144410Sbakhtiyar 
3041c144410Sbakhtiyar   // Suspend async function at the end of an entry block, and resume it using
3051c144410Sbakhtiyar   // Async resume operation (execution will be resumed in a thread managed by
3061c144410Sbakhtiyar   // the async runtime).
3071c144410Sbakhtiyar   {
308ace01605SRiver Riddle     cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
3091c144410Sbakhtiyar     builder.setInsertionPointToEnd(coro.entry);
3101c144410Sbakhtiyar 
3111c144410Sbakhtiyar     // Save the coroutine state: async.coro.save
3121c144410Sbakhtiyar     auto coroSaveOp =
3131c144410Sbakhtiyar         builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
3141c144410Sbakhtiyar 
3151c144410Sbakhtiyar     // Pass coroutine to the runtime to be resumed on a runtime managed
3161c144410Sbakhtiyar     // thread.
3171c144410Sbakhtiyar     builder.create<RuntimeResumeOp>(coro.coroHandle);
3181c144410Sbakhtiyar 
3191c144410Sbakhtiyar     // Add async.coro.suspend as a suspended block terminator.
3201c144410Sbakhtiyar     builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend,
3211c144410Sbakhtiyar                                   branch.getDest(), coro.cleanup);
3221c144410Sbakhtiyar 
3231c144410Sbakhtiyar     branch.erase();
3241c144410Sbakhtiyar   }
32525f80e16SEugene Zhulenev 
32625f80e16SEugene Zhulenev   // Replace the original `async.execute` with a call to outlined function.
3271c144410Sbakhtiyar   {
32825f80e16SEugene Zhulenev     ImplicitLocOpBuilder callBuilder(loc, execute);
32923aa5a74SRiver Riddle     auto callOutlinedFunc = callBuilder.create<func::CallOp>(
33025f80e16SEugene Zhulenev         func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
33125f80e16SEugene Zhulenev     execute.replaceAllUsesWith(callOutlinedFunc.getResults());
33225f80e16SEugene Zhulenev     execute.erase();
3331c144410Sbakhtiyar   }
33425f80e16SEugene Zhulenev 
33525f80e16SEugene Zhulenev   return {func, coro};
33625f80e16SEugene Zhulenev }
33725f80e16SEugene Zhulenev 
33825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
// Convert async.create_group operation to async.runtime.create_group.
34025f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
34125f80e16SEugene Zhulenev 
34225f80e16SEugene Zhulenev namespace {
34325f80e16SEugene Zhulenev class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
34425f80e16SEugene Zhulenev public:
34525f80e16SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
34625f80e16SEugene Zhulenev 
34725f80e16SEugene Zhulenev   LogicalResult
matchAndRewrite(CreateGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const348b54c724bSRiver Riddle   matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
34925f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
350d43b2360SEugene Zhulenev     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
351b54c724bSRiver Riddle         op, GroupType::get(op->getContext()), adaptor.getOperands());
35225f80e16SEugene Zhulenev     return success();
35325f80e16SEugene Zhulenev   }
35425f80e16SEugene Zhulenev };
35525f80e16SEugene Zhulenev } // namespace
35625f80e16SEugene Zhulenev 
35725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
35825f80e16SEugene Zhulenev // Convert async.add_to_group operation to async.runtime.add_to_group.
35925f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
36025f80e16SEugene Zhulenev 
36125f80e16SEugene Zhulenev namespace {
36225f80e16SEugene Zhulenev class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
36325f80e16SEugene Zhulenev public:
36425f80e16SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
36525f80e16SEugene Zhulenev 
36625f80e16SEugene Zhulenev   LogicalResult
matchAndRewrite(AddToGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const367b54c724bSRiver Riddle   matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
36825f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
36925f80e16SEugene Zhulenev     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
370b54c724bSRiver Riddle         op, rewriter.getIndexType(), adaptor.getOperands());
37125f80e16SEugene Zhulenev     return success();
37225f80e16SEugene Zhulenev   }
37325f80e16SEugene Zhulenev };
37425f80e16SEugene Zhulenev } // namespace
37525f80e16SEugene Zhulenev 
37625f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
37725f80e16SEugene Zhulenev // Convert async.await and async.await_all operations to the async.runtime.await
37825f80e16SEugene Zhulenev // or async.runtime.await_and_resume operations.
37925f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
38025f80e16SEugene Zhulenev 
38125f80e16SEugene Zhulenev namespace {
38225f80e16SEugene Zhulenev template <typename AwaitType, typename AwaitableType>
38325f80e16SEugene Zhulenev class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
38425f80e16SEugene Zhulenev   using AwaitAdaptor = typename AwaitType::Adaptor;
38525f80e16SEugene Zhulenev 
38625f80e16SEugene Zhulenev public:
AwaitOpLoweringBase(MLIRContext * ctx,llvm::DenseMap<func::FuncOp,CoroMachinery> & outlinedFunctions)38758ceae95SRiver Riddle   AwaitOpLoweringBase(
38858ceae95SRiver Riddle       MLIRContext *ctx,
38958ceae95SRiver Riddle       llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
39025f80e16SEugene Zhulenev       : OpConversionPattern<AwaitType>(ctx),
39125f80e16SEugene Zhulenev         outlinedFunctions(outlinedFunctions) {}
39225f80e16SEugene Zhulenev 
39325f80e16SEugene Zhulenev   LogicalResult
matchAndRewrite(AwaitType op,typename AwaitType::Adaptor adaptor,ConversionPatternRewriter & rewriter) const394b54c724bSRiver Riddle   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
39525f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
39625f80e16SEugene Zhulenev     // We can only await on one the `AwaitableType` (for `await` it can be
39725f80e16SEugene Zhulenev     // a `token` or a `value`, for `await_all` it must be a `group`).
39825f80e16SEugene Zhulenev     if (!op.operand().getType().template isa<AwaitableType>())
39925f80e16SEugene Zhulenev       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
40025f80e16SEugene Zhulenev 
40125f80e16SEugene Zhulenev     // Check if await operation is inside the outlined coroutine function.
40258ceae95SRiver Riddle     auto func = op->template getParentOfType<func::FuncOp>();
40325f80e16SEugene Zhulenev     auto outlined = outlinedFunctions.find(func);
40425f80e16SEugene Zhulenev     const bool isInCoroutine = outlined != outlinedFunctions.end();
40525f80e16SEugene Zhulenev 
40625f80e16SEugene Zhulenev     Location loc = op->getLoc();
407b54c724bSRiver Riddle     Value operand = adaptor.operand();
40825f80e16SEugene Zhulenev 
409fd52b435SEugene Zhulenev     Type i1 = rewriter.getI1Type();
410fd52b435SEugene Zhulenev 
41125f80e16SEugene Zhulenev     // Inside regular functions we use the blocking wait operation to wait for
41225f80e16SEugene Zhulenev     // the async object (token, value or group) to become available.
413fd52b435SEugene Zhulenev     if (!isInCoroutine) {
414fd52b435SEugene Zhulenev       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
415fd52b435SEugene Zhulenev       builder.create<RuntimeAwaitOp>(loc, operand);
416fd52b435SEugene Zhulenev 
417fd52b435SEugene Zhulenev       // Assert that the awaited operands is not in the error state.
418fd52b435SEugene Zhulenev       Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
419a54f4eaeSMogball       Value notError = builder.create<arith::XOrIOp>(
420a54f4eaeSMogball           isError, builder.create<arith::ConstantOp>(
421a54f4eaeSMogball                        loc, i1, builder.getIntegerAttr(i1, 1)));
422fd52b435SEugene Zhulenev 
423ace01605SRiver Riddle       builder.create<cf::AssertOp>(notError,
424fd52b435SEugene Zhulenev                                    "Awaited async operand is in error state");
425fd52b435SEugene Zhulenev     }
42625f80e16SEugene Zhulenev 
42725f80e16SEugene Zhulenev     // Inside the coroutine we convert await operation into coroutine suspension
42825f80e16SEugene Zhulenev     // point, and resume execution asynchronously.
42925f80e16SEugene Zhulenev     if (isInCoroutine) {
43039957aa4SEugene Zhulenev       CoroMachinery &coro = outlined->getSecond();
43125f80e16SEugene Zhulenev       Block *suspended = op->getBlock();
43225f80e16SEugene Zhulenev 
43325f80e16SEugene Zhulenev       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
43425f80e16SEugene Zhulenev       MLIRContext *ctx = op->getContext();
43525f80e16SEugene Zhulenev 
43625f80e16SEugene Zhulenev       // Save the coroutine state and resume on a runtime managed thread when
43725f80e16SEugene Zhulenev       // the operand becomes available.
43825f80e16SEugene Zhulenev       auto coroSaveOp =
43925f80e16SEugene Zhulenev           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
44025f80e16SEugene Zhulenev       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
44125f80e16SEugene Zhulenev 
44225f80e16SEugene Zhulenev       // Split the entry block before the await operation.
44325f80e16SEugene Zhulenev       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
44425f80e16SEugene Zhulenev 
44525f80e16SEugene Zhulenev       // Add async.coro.suspend as a suspended block terminator.
44625f80e16SEugene Zhulenev       builder.setInsertionPointToEnd(suspended);
44725f80e16SEugene Zhulenev       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
44825f80e16SEugene Zhulenev                                     coro.cleanup);
44925f80e16SEugene Zhulenev 
45039957aa4SEugene Zhulenev       // Split the resume block into error checking and continuation.
45139957aa4SEugene Zhulenev       Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
45239957aa4SEugene Zhulenev 
45339957aa4SEugene Zhulenev       // Check if the awaited value is in the error state.
45439957aa4SEugene Zhulenev       builder.setInsertionPointToStart(resume);
455fd52b435SEugene Zhulenev       auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
456ace01605SRiver Riddle       builder.create<cf::CondBranchOp>(isError,
45739957aa4SEugene Zhulenev                                        /*trueDest=*/setupSetErrorBlock(coro),
45839957aa4SEugene Zhulenev                                        /*trueArgs=*/ArrayRef<Value>(),
45939957aa4SEugene Zhulenev                                        /*falseDest=*/continuation,
46039957aa4SEugene Zhulenev                                        /*falseArgs=*/ArrayRef<Value>());
46139957aa4SEugene Zhulenev 
46239957aa4SEugene Zhulenev       // Make sure that replacement value will be constructed in the
46339957aa4SEugene Zhulenev       // continuation block.
46439957aa4SEugene Zhulenev       rewriter.setInsertionPointToStart(continuation);
46539957aa4SEugene Zhulenev     }
46625f80e16SEugene Zhulenev 
46725f80e16SEugene Zhulenev     // Erase or replace the await operation with the new value.
46825f80e16SEugene Zhulenev     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
46925f80e16SEugene Zhulenev       rewriter.replaceOp(op, replaceWith);
47025f80e16SEugene Zhulenev     else
47125f80e16SEugene Zhulenev       rewriter.eraseOp(op);
47225f80e16SEugene Zhulenev 
47325f80e16SEugene Zhulenev     return success();
47425f80e16SEugene Zhulenev   }
47525f80e16SEugene Zhulenev 
getReplacementValue(AwaitType op,Value operand,ConversionPatternRewriter & rewriter) const47625f80e16SEugene Zhulenev   virtual Value getReplacementValue(AwaitType op, Value operand,
47725f80e16SEugene Zhulenev                                     ConversionPatternRewriter &rewriter) const {
47825f80e16SEugene Zhulenev     return Value();
47925f80e16SEugene Zhulenev   }
48025f80e16SEugene Zhulenev 
48125f80e16SEugene Zhulenev private:
48258ceae95SRiver Riddle   llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
48325f80e16SEugene Zhulenev };
48425f80e16SEugene Zhulenev 
48525f80e16SEugene Zhulenev /// Lowering for `async.await` with a token operand.
48625f80e16SEugene Zhulenev class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
48725f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
48825f80e16SEugene Zhulenev 
48925f80e16SEugene Zhulenev public:
49025f80e16SEugene Zhulenev   using Base::Base;
49125f80e16SEugene Zhulenev };
49225f80e16SEugene Zhulenev 
49325f80e16SEugene Zhulenev /// Lowering for `async.await` with a value operand.
49425f80e16SEugene Zhulenev class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
49525f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
49625f80e16SEugene Zhulenev 
49725f80e16SEugene Zhulenev public:
49825f80e16SEugene Zhulenev   using Base::Base;
49925f80e16SEugene Zhulenev 
50025f80e16SEugene Zhulenev   Value
getReplacementValue(AwaitOp op,Value operand,ConversionPatternRewriter & rewriter) const50125f80e16SEugene Zhulenev   getReplacementValue(AwaitOp op, Value operand,
50225f80e16SEugene Zhulenev                       ConversionPatternRewriter &rewriter) const override {
50325f80e16SEugene Zhulenev     // Load from the async value storage.
50425f80e16SEugene Zhulenev     auto valueType = operand.getType().cast<ValueType>().getValueType();
50525f80e16SEugene Zhulenev     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
50625f80e16SEugene Zhulenev   }
50725f80e16SEugene Zhulenev };
50825f80e16SEugene Zhulenev 
50925f80e16SEugene Zhulenev /// Lowering for `async.await_all` operation.
51025f80e16SEugene Zhulenev class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
51125f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
51225f80e16SEugene Zhulenev 
51325f80e16SEugene Zhulenev public:
51425f80e16SEugene Zhulenev   using Base::Base;
51525f80e16SEugene Zhulenev };
51625f80e16SEugene Zhulenev 
51725f80e16SEugene Zhulenev } // namespace
51825f80e16SEugene Zhulenev 
51925f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
52025f80e16SEugene Zhulenev // Convert async.yield operation to async.runtime operations.
52125f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
52225f80e16SEugene Zhulenev 
52325f80e16SEugene Zhulenev class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
52425f80e16SEugene Zhulenev public:
YieldOpLowering(MLIRContext * ctx,const llvm::DenseMap<func::FuncOp,CoroMachinery> & outlinedFunctions)52525f80e16SEugene Zhulenev   YieldOpLowering(
52625f80e16SEugene Zhulenev       MLIRContext *ctx,
52758ceae95SRiver Riddle       const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
52825f80e16SEugene Zhulenev       : OpConversionPattern<async::YieldOp>(ctx),
52925f80e16SEugene Zhulenev         outlinedFunctions(outlinedFunctions) {}
53025f80e16SEugene Zhulenev 
53125f80e16SEugene Zhulenev   LogicalResult
matchAndRewrite(async::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const532b54c724bSRiver Riddle   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
53325f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
53439957aa4SEugene Zhulenev     // Check if yield operation is inside the async coroutine function.
53558ceae95SRiver Riddle     auto func = op->template getParentOfType<func::FuncOp>();
53625f80e16SEugene Zhulenev     auto outlined = outlinedFunctions.find(func);
53725f80e16SEugene Zhulenev     if (outlined == outlinedFunctions.end())
53825f80e16SEugene Zhulenev       return rewriter.notifyMatchFailure(
53939957aa4SEugene Zhulenev           op, "operation is not inside the async coroutine function");
54025f80e16SEugene Zhulenev 
54125f80e16SEugene Zhulenev     Location loc = op->getLoc();
54225f80e16SEugene Zhulenev     const CoroMachinery &coro = outlined->getSecond();
54325f80e16SEugene Zhulenev 
54425f80e16SEugene Zhulenev     // Store yielded values into the async values storage and switch async
54525f80e16SEugene Zhulenev     // values state to available.
546b54c724bSRiver Riddle     for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
54725f80e16SEugene Zhulenev       Value yieldValue = std::get<0>(tuple);
54825f80e16SEugene Zhulenev       Value asyncValue = std::get<1>(tuple);
54925f80e16SEugene Zhulenev       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
55025f80e16SEugene Zhulenev       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
55125f80e16SEugene Zhulenev     }
55225f80e16SEugene Zhulenev 
55325f80e16SEugene Zhulenev     // Switch the coroutine completion token to available state.
55425f80e16SEugene Zhulenev     rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
55525f80e16SEugene Zhulenev 
55625f80e16SEugene Zhulenev     return success();
55725f80e16SEugene Zhulenev   }
55825f80e16SEugene Zhulenev 
55925f80e16SEugene Zhulenev private:
56058ceae95SRiver Riddle   const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
56125f80e16SEugene Zhulenev };
56225f80e16SEugene Zhulenev 
56325f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
56423aa5a74SRiver Riddle // Convert cf.assert operation to cf.cond_br into `set_error` block.
56539957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
56639957aa4SEugene Zhulenev 
567ace01605SRiver Riddle class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
56839957aa4SEugene Zhulenev public:
AssertOpLowering(MLIRContext * ctx,llvm::DenseMap<func::FuncOp,CoroMachinery> & outlinedFunctions)56958ceae95SRiver Riddle   AssertOpLowering(
57058ceae95SRiver Riddle       MLIRContext *ctx,
57158ceae95SRiver Riddle       llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
572ace01605SRiver Riddle       : OpConversionPattern<cf::AssertOp>(ctx),
57339957aa4SEugene Zhulenev         outlinedFunctions(outlinedFunctions) {}
57439957aa4SEugene Zhulenev 
57539957aa4SEugene Zhulenev   LogicalResult
matchAndRewrite(cf::AssertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const576ace01605SRiver Riddle   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
57739957aa4SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
57839957aa4SEugene Zhulenev     // Check if assert operation is inside the async coroutine function.
57958ceae95SRiver Riddle     auto func = op->template getParentOfType<func::FuncOp>();
58039957aa4SEugene Zhulenev     auto outlined = outlinedFunctions.find(func);
58139957aa4SEugene Zhulenev     if (outlined == outlinedFunctions.end())
58239957aa4SEugene Zhulenev       return rewriter.notifyMatchFailure(
58339957aa4SEugene Zhulenev           op, "operation is not inside the async coroutine function");
58439957aa4SEugene Zhulenev 
58539957aa4SEugene Zhulenev     Location loc = op->getLoc();
58639957aa4SEugene Zhulenev     CoroMachinery &coro = outlined->getSecond();
58739957aa4SEugene Zhulenev 
58839957aa4SEugene Zhulenev     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
58939957aa4SEugene Zhulenev     rewriter.setInsertionPointToEnd(cont->getPrevNode());
590ace01605SRiver Riddle     rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
59139957aa4SEugene Zhulenev                                       /*trueDest=*/cont,
59239957aa4SEugene Zhulenev                                       /*trueArgs=*/ArrayRef<Value>(),
59339957aa4SEugene Zhulenev                                       /*falseDest=*/setupSetErrorBlock(coro),
59439957aa4SEugene Zhulenev                                       /*falseArgs=*/ArrayRef<Value>());
59539957aa4SEugene Zhulenev     rewriter.eraseOp(op);
59639957aa4SEugene Zhulenev 
59739957aa4SEugene Zhulenev     return success();
59839957aa4SEugene Zhulenev   }
59939957aa4SEugene Zhulenev 
60039957aa4SEugene Zhulenev private:
60158ceae95SRiver Riddle   llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
60239957aa4SEugene Zhulenev };
60339957aa4SEugene Zhulenev 
60439957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
60525f80e16SEugene Zhulenev 
6066ea22d46Sbakhtiyar /// Rewrite a func as a coroutine by:
6076ea22d46Sbakhtiyar /// 1) Wrapping the results into `async.value`.
6086ea22d46Sbakhtiyar /// 2) Prepending the results with `async.token`.
6096ea22d46Sbakhtiyar /// 3) Setting up coroutine blocks.
6106ea22d46Sbakhtiyar /// 4) Rewriting return ops as yield op and branch op into the suspend block.
rewriteFuncAsCoroutine(func::FuncOp func)61158ceae95SRiver Riddle static CoroMachinery rewriteFuncAsCoroutine(func::FuncOp func) {
6126ea22d46Sbakhtiyar   auto *ctx = func->getContext();
6136ea22d46Sbakhtiyar   auto loc = func.getLoc();
6146ea22d46Sbakhtiyar   SmallVector<Type> resultTypes;
6156ea22d46Sbakhtiyar   resultTypes.reserve(func.getCallableResults().size());
6166ea22d46Sbakhtiyar   llvm::transform(func.getCallableResults(), std::back_inserter(resultTypes),
6176ea22d46Sbakhtiyar                   [](Type type) { return ValueType::get(type); });
6184a3460a7SRiver Riddle   func.setType(
6194a3460a7SRiver Riddle       FunctionType::get(ctx, func.getFunctionType().getInputs(), resultTypes));
6206ea22d46Sbakhtiyar   func.insertResult(0, TokenType::get(ctx), {});
6216ea22d46Sbakhtiyar   for (Block &block : func.getBlocks()) {
6226ea22d46Sbakhtiyar     Operation *terminator = block.getTerminator();
62323aa5a74SRiver Riddle     if (auto returnOp = dyn_cast<func::ReturnOp>(*terminator)) {
6246ea22d46Sbakhtiyar       ImplicitLocOpBuilder builder(loc, returnOp);
6256ea22d46Sbakhtiyar       builder.create<YieldOp>(returnOp.getOperands());
6266ea22d46Sbakhtiyar       returnOp.erase();
6276ea22d46Sbakhtiyar     }
6286ea22d46Sbakhtiyar   }
6291c144410Sbakhtiyar   return setupCoroMachinery(func);
6306ea22d46Sbakhtiyar }
6316ea22d46Sbakhtiyar 
6326ea22d46Sbakhtiyar /// Rewrites a call into a function that has been rewritten as a coroutine.
6336ea22d46Sbakhtiyar ///
6346ea22d46Sbakhtiyar /// The invocation of this function is safe only when call ops are traversed in
6356ea22d46Sbakhtiyar /// reverse order of how they appear in a single block. See `funcsToCoroutines`.
rewriteCallsiteForCoroutine(func::CallOp oldCall,func::FuncOp func)63658ceae95SRiver Riddle static void rewriteCallsiteForCoroutine(func::CallOp oldCall,
63758ceae95SRiver Riddle                                         func::FuncOp func) {
6386ea22d46Sbakhtiyar   auto loc = func.getLoc();
6396ea22d46Sbakhtiyar   ImplicitLocOpBuilder callBuilder(loc, oldCall);
64023aa5a74SRiver Riddle   auto newCall = callBuilder.create<func::CallOp>(
6416ea22d46Sbakhtiyar       func.getName(), func.getCallableResults(), oldCall.getArgOperands());
6426ea22d46Sbakhtiyar 
6436ea22d46Sbakhtiyar   // Await on the async token and all the value results and unwrap the latter.
6446ea22d46Sbakhtiyar   callBuilder.create<AwaitOp>(loc, newCall.getResults().front());
6456ea22d46Sbakhtiyar   SmallVector<Value> unwrappedResults;
6466ea22d46Sbakhtiyar   unwrappedResults.reserve(newCall->getResults().size() - 1);
6476ea22d46Sbakhtiyar   for (Value result : newCall.getResults().drop_front())
6486ea22d46Sbakhtiyar     unwrappedResults.push_back(
6496ea22d46Sbakhtiyar         callBuilder.create<AwaitOp>(loc, result).result());
6506ea22d46Sbakhtiyar   // Careful, when result of a call is piped into another call this could lead
6516ea22d46Sbakhtiyar   // to a dangling pointer.
6526ea22d46Sbakhtiyar   oldCall.replaceAllUsesWith(unwrappedResults);
6536ea22d46Sbakhtiyar   oldCall.erase();
6546ea22d46Sbakhtiyar }
6556ea22d46Sbakhtiyar 
isAllowedToBlock(func::FuncOp func)65658ceae95SRiver Riddle static bool isAllowedToBlock(func::FuncOp func) {
6579a5bc836Sbakhtiyar   return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
6589a5bc836Sbakhtiyar }
6599a5bc836Sbakhtiyar 
funcsToCoroutines(ModuleOp module,llvm::DenseMap<func::FuncOp,CoroMachinery> & outlinedFunctions)66058ceae95SRiver Riddle static LogicalResult funcsToCoroutines(
66158ceae95SRiver Riddle     ModuleOp module,
66258ceae95SRiver Riddle     llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions) {
6636ea22d46Sbakhtiyar   // The following code supports the general case when 2 functions mutually
6646ea22d46Sbakhtiyar   // recurse into each other. Because of this and that we are relying on
6656ea22d46Sbakhtiyar   // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
6666ea22d46Sbakhtiyar   // a FuncOp while inserting an equivalent coroutine, because that could lead
6676ea22d46Sbakhtiyar   // to dangling pointers.
6686ea22d46Sbakhtiyar 
66958ceae95SRiver Riddle   SmallVector<func::FuncOp> funcWorklist;
6706ea22d46Sbakhtiyar 
6716ea22d46Sbakhtiyar   // Careful, it's okay to add a func to the worklist multiple times if and only
6726ea22d46Sbakhtiyar   // if the loop processing the worklist will skip the functions that have
6736ea22d46Sbakhtiyar   // already been converted to coroutines.
67458ceae95SRiver Riddle   auto addToWorklist = [&](func::FuncOp func) {
6759a5bc836Sbakhtiyar     if (isAllowedToBlock(func))
6769a5bc836Sbakhtiyar       return;
6776ea22d46Sbakhtiyar     // N.B. To refactor this code into a separate pass the lookup in
6786ea22d46Sbakhtiyar     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
6796ea22d46Sbakhtiyar     // func and recognizing if it has a coroutine structure is messy. Passing
6806ea22d46Sbakhtiyar     // this dict between the passes is ugly.
6819a5bc836Sbakhtiyar     if (isAllowedToBlock(func) ||
6829a5bc836Sbakhtiyar         outlinedFunctions.find(func) == outlinedFunctions.end()) {
683f8d5c73cSRiver Riddle       for (Operation &op : func.getBody().getOps()) {
6846ea22d46Sbakhtiyar         if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
6856ea22d46Sbakhtiyar           funcWorklist.push_back(func);
6866ea22d46Sbakhtiyar           break;
6876ea22d46Sbakhtiyar         }
6886ea22d46Sbakhtiyar       }
6896ea22d46Sbakhtiyar     }
6906ea22d46Sbakhtiyar   };
6916ea22d46Sbakhtiyar 
6926ea22d46Sbakhtiyar   // Traverse in post-order collecting for each func op the await ops it has.
69358ceae95SRiver Riddle   for (func::FuncOp func : module.getOps<func::FuncOp>())
6946ea22d46Sbakhtiyar     addToWorklist(func);
6956ea22d46Sbakhtiyar 
6966ea22d46Sbakhtiyar   SymbolTableCollection symbolTable;
6976ea22d46Sbakhtiyar   SymbolUserMap symbolUserMap(symbolTable, module);
6986ea22d46Sbakhtiyar 
6996ea22d46Sbakhtiyar   // Rewrite funcs, while updating call sites and adding them to the worklist.
7006ea22d46Sbakhtiyar   while (!funcWorklist.empty()) {
7016ea22d46Sbakhtiyar     auto func = funcWorklist.pop_back_val();
7026ea22d46Sbakhtiyar     auto insertion = outlinedFunctions.insert({func, CoroMachinery{}});
7036ea22d46Sbakhtiyar     if (!insertion.second)
7046ea22d46Sbakhtiyar       // This function has already been processed because this is either
7056ea22d46Sbakhtiyar       // the corecursive case, or a caller with multiple calls to a newly
7066ea22d46Sbakhtiyar       // created corouting. Either way, skip updating the call sites.
7076ea22d46Sbakhtiyar       continue;
7086ea22d46Sbakhtiyar     insertion.first->second = rewriteFuncAsCoroutine(func);
7096ea22d46Sbakhtiyar     SmallVector<Operation *> users(symbolUserMap.getUsers(func).begin(),
7106ea22d46Sbakhtiyar                                    symbolUserMap.getUsers(func).end());
7116ea22d46Sbakhtiyar     // If there are multiple calls from the same block they need to be traversed
7126ea22d46Sbakhtiyar     // in reverse order so that symbolUserMap references are not invalidated
7136ea22d46Sbakhtiyar     // when updating the users of the call op which is earlier in the block.
7146ea22d46Sbakhtiyar     llvm::sort(users, [](Operation *a, Operation *b) {
7156ea22d46Sbakhtiyar       Block *blockA = a->getBlock();
7166ea22d46Sbakhtiyar       Block *blockB = b->getBlock();
7176ea22d46Sbakhtiyar       // Impose arbitrary order on blocks so that there is a well-defined order.
7186ea22d46Sbakhtiyar       return blockA > blockB || (blockA == blockB && !a->isBeforeInBlock(b));
7196ea22d46Sbakhtiyar     });
7206ea22d46Sbakhtiyar     // Rewrite the callsites to await on results of the newly created coroutine.
7216ea22d46Sbakhtiyar     for (Operation *op : users) {
72223aa5a74SRiver Riddle       if (func::CallOp call = dyn_cast<func::CallOp>(*op)) {
72358ceae95SRiver Riddle         func::FuncOp caller = call->getParentOfType<func::FuncOp>();
7246ea22d46Sbakhtiyar         rewriteCallsiteForCoroutine(call, func); // Careful, erases the call op.
7256ea22d46Sbakhtiyar         addToWorklist(caller);
7266ea22d46Sbakhtiyar       } else {
7276ea22d46Sbakhtiyar         op->emitError("Unexpected reference to func referenced by symbol");
7286ea22d46Sbakhtiyar         return failure();
7296ea22d46Sbakhtiyar       }
7306ea22d46Sbakhtiyar     }
7316ea22d46Sbakhtiyar   }
7326ea22d46Sbakhtiyar   return success();
7336ea22d46Sbakhtiyar }
7346ea22d46Sbakhtiyar 
7356ea22d46Sbakhtiyar //===----------------------------------------------------------------------===//
runOnOperation()73625f80e16SEugene Zhulenev void AsyncToAsyncRuntimePass::runOnOperation() {
73725f80e16SEugene Zhulenev   ModuleOp module = getOperation();
73825f80e16SEugene Zhulenev   SymbolTable symbolTable(module);
73925f80e16SEugene Zhulenev 
74025f80e16SEugene Zhulenev   // Outline all `async.execute` body regions into async functions (coroutines).
74158ceae95SRiver Riddle   llvm::DenseMap<func::FuncOp, CoroMachinery> outlinedFunctions;
74225f80e16SEugene Zhulenev 
74325f80e16SEugene Zhulenev   module.walk([&](ExecuteOp execute) {
74425f80e16SEugene Zhulenev     outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
74525f80e16SEugene Zhulenev   });
74625f80e16SEugene Zhulenev 
74725f80e16SEugene Zhulenev   LLVM_DEBUG({
74825f80e16SEugene Zhulenev     llvm::dbgs() << "Outlined " << outlinedFunctions.size()
74925f80e16SEugene Zhulenev                  << " functions built from async.execute operations\n";
75025f80e16SEugene Zhulenev   });
75125f80e16SEugene Zhulenev 
752de7a4e53SEugene Zhulenev   // Returns true if operation is inside the coroutine.
753de7a4e53SEugene Zhulenev   auto isInCoroutine = [&](Operation *op) -> bool {
75458ceae95SRiver Riddle     auto parentFunc = op->getParentOfType<func::FuncOp>();
755de7a4e53SEugene Zhulenev     return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
756de7a4e53SEugene Zhulenev   };
757de7a4e53SEugene Zhulenev 
7586ea22d46Sbakhtiyar   if (eliminateBlockingAwaitOps &&
7596ea22d46Sbakhtiyar       failed(funcsToCoroutines(module, outlinedFunctions))) {
7606ea22d46Sbakhtiyar     signalPassFailure();
7616ea22d46Sbakhtiyar     return;
7626ea22d46Sbakhtiyar   }
7636ea22d46Sbakhtiyar 
76425f80e16SEugene Zhulenev   // Lower async operations to async.runtime operations.
76525f80e16SEugene Zhulenev   MLIRContext *ctx = module->getContext();
766dc4e913bSChris Lattner   RewritePatternSet asyncPatterns(ctx);
76725f80e16SEugene Zhulenev 
768de7a4e53SEugene Zhulenev   // Conversion to async runtime augments original CFG with the coroutine CFG,
769de7a4e53SEugene Zhulenev   // and we have to make sure that structured control flow operations with async
770de7a4e53SEugene Zhulenev   // operations in nested regions will be converted to branch-based control flow
771de7a4e53SEugene Zhulenev   // before we add the coroutine basic blocks.
772ace01605SRiver Riddle   populateSCFToControlFlowConversionPatterns(asyncPatterns);
773de7a4e53SEugene Zhulenev 
77425f80e16SEugene Zhulenev   // Async lowering does not use type converter because it must preserve all
77525f80e16SEugene Zhulenev   // types for async.runtime operations.
776dc4e913bSChris Lattner   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
777dc4e913bSChris Lattner   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
77825f80e16SEugene Zhulenev                     AwaitAllOpLowering, YieldOpLowering>(ctx,
77925f80e16SEugene Zhulenev                                                          outlinedFunctions);
78025f80e16SEugene Zhulenev 
78139957aa4SEugene Zhulenev   // Lower assertions to conditional branches into error blocks.
78239957aa4SEugene Zhulenev   asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
78339957aa4SEugene Zhulenev 
78425f80e16SEugene Zhulenev   // All high level async operations must be lowered to the runtime operations.
78525f80e16SEugene Zhulenev   ConversionTarget runtimeTarget(*ctx);
78625f80e16SEugene Zhulenev   runtimeTarget.addLegalDialect<AsyncDialect>();
78725f80e16SEugene Zhulenev   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
78825f80e16SEugene Zhulenev   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
78925f80e16SEugene Zhulenev 
790de7a4e53SEugene Zhulenev   // Decide if structured control flow has to be lowered to branch-based CFG.
791de7a4e53SEugene Zhulenev   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
792de7a4e53SEugene Zhulenev     auto walkResult = op->walk([&](Operation *nested) {
793de7a4e53SEugene Zhulenev       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
794de7a4e53SEugene Zhulenev       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
795de7a4e53SEugene Zhulenev                                               : WalkResult::advance();
796de7a4e53SEugene Zhulenev     });
797de7a4e53SEugene Zhulenev     return !walkResult.wasInterrupted();
798de7a4e53SEugene Zhulenev   });
799ace01605SRiver Riddle   runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
80023aa5a74SRiver Riddle                            func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
801de7a4e53SEugene Zhulenev 
8028f23fac4SEugene Zhulenev   // Assertions must be converted to runtime errors inside async functions.
803ace01605SRiver Riddle   runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
804ace01605SRiver Riddle       [&](cf::AssertOp op) -> bool {
80558ceae95SRiver Riddle         auto func = op->getParentOfType<func::FuncOp>();
8068f23fac4SEugene Zhulenev         return outlinedFunctions.find(func) == outlinedFunctions.end();
8078f23fac4SEugene Zhulenev       });
80839957aa4SEugene Zhulenev 
8096ea22d46Sbakhtiyar   if (eliminateBlockingAwaitOps)
8109a5bc836Sbakhtiyar     runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
8119a5bc836Sbakhtiyar         [&](RuntimeAwaitOp op) -> bool {
81258ceae95SRiver Riddle           return isAllowedToBlock(op->getParentOfType<func::FuncOp>());
8139a5bc836Sbakhtiyar         });
8146ea22d46Sbakhtiyar 
81525f80e16SEugene Zhulenev   if (failed(applyPartialConversion(module, runtimeTarget,
81625f80e16SEugene Zhulenev                                     std::move(asyncPatterns)))) {
81725f80e16SEugene Zhulenev     signalPassFailure();
81825f80e16SEugene Zhulenev     return;
81925f80e16SEugene Zhulenev   }
82025f80e16SEugene Zhulenev }
82125f80e16SEugene Zhulenev 
createAsyncToAsyncRuntimePass()82225f80e16SEugene Zhulenev std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
82325f80e16SEugene Zhulenev   return std::make_unique<AsyncToAsyncRuntimePass>();
82425f80e16SEugene Zhulenev }
825