136ce915aSLei Zhang //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
236ce915aSLei Zhang //
336ce915aSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
436ce915aSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
536ce915aSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
636ce915aSLei Zhang //
736ce915aSLei Zhang //===----------------------------------------------------------------------===//
836ce915aSLei Zhang //
936ce915aSLei Zhang // This file implements basic Async runtime API for supporting Async dialect
1036ce915aSLei Zhang // to LLVM dialect lowering.
1136ce915aSLei Zhang //
1236ce915aSLei Zhang //===----------------------------------------------------------------------===//
1336ce915aSLei Zhang
1436ce915aSLei Zhang #include "mlir/ExecutionEngine/AsyncRuntime.h"
1536ce915aSLei Zhang
1636ce915aSLei Zhang #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
1736ce915aSLei Zhang
18c30ab6c2SEugene Zhulenev #include <atomic>
19a86a9b5eSEugene Zhulenev #include <cassert>
2036ce915aSLei Zhang #include <condition_variable>
2136ce915aSLei Zhang #include <functional>
2236ce915aSLei Zhang #include <iostream>
2336ce915aSLei Zhang #include <mutex>
2436ce915aSLei Zhang #include <thread>
2536ce915aSLei Zhang #include <vector>
2636ce915aSLei Zhang
271fc98642SEugene Zhulenev #include "llvm/ADT/StringMap.h"
28bb0e6213SEugene Zhulenev #include "llvm/Support/ThreadPool.h"
291fc98642SEugene Zhulenev
3011f1027bSEugene Zhulenev using namespace mlir::runtime;
3111f1027bSEugene Zhulenev
3236ce915aSLei Zhang //===----------------------------------------------------------------------===//
3336ce915aSLei Zhang // Async runtime API.
3436ce915aSLei Zhang //===----------------------------------------------------------------------===//
3536ce915aSLei Zhang
3611f1027bSEugene Zhulenev namespace mlir {
3711f1027bSEugene Zhulenev namespace runtime {
38a86a9b5eSEugene Zhulenev namespace {
39a86a9b5eSEugene Zhulenev
40a86a9b5eSEugene Zhulenev // Forward declare class defined below.
41a86a9b5eSEugene Zhulenev class RefCounted;
42a86a9b5eSEugene Zhulenev
43a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
44a86a9b5eSEugene Zhulenev // AsyncRuntime orchestrates all async operations and Async runtime API is built
45a86a9b5eSEugene Zhulenev // on top of the default runtime instance.
46a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
47a86a9b5eSEugene Zhulenev
48a86a9b5eSEugene Zhulenev class AsyncRuntime {
49a86a9b5eSEugene Zhulenev public:
AsyncRuntime()50a86a9b5eSEugene Zhulenev AsyncRuntime() : numRefCountedObjects(0) {}
51a86a9b5eSEugene Zhulenev
~AsyncRuntime()52a86a9b5eSEugene Zhulenev ~AsyncRuntime() {
53bb0e6213SEugene Zhulenev threadPool.wait(); // wait for the completion of all async tasks
54a86a9b5eSEugene Zhulenev assert(getNumRefCountedObjects() == 0 &&
55a86a9b5eSEugene Zhulenev "all ref counted objects must be destroyed");
56a86a9b5eSEugene Zhulenev }
57a86a9b5eSEugene Zhulenev
getNumRefCountedObjects()5892db09cdSEugene Zhulenev int64_t getNumRefCountedObjects() {
59a86a9b5eSEugene Zhulenev return numRefCountedObjects.load(std::memory_order_relaxed);
60a86a9b5eSEugene Zhulenev }
61a86a9b5eSEugene Zhulenev
getThreadPool()62bb0e6213SEugene Zhulenev llvm::ThreadPool &getThreadPool() { return threadPool; }
63bb0e6213SEugene Zhulenev
64a86a9b5eSEugene Zhulenev private:
65a86a9b5eSEugene Zhulenev friend class RefCounted;
66a86a9b5eSEugene Zhulenev
67a86a9b5eSEugene Zhulenev // Count the total number of reference counted objects in this instance
68a86a9b5eSEugene Zhulenev // of an AsyncRuntime. For debugging purposes only.
addNumRefCountedObjects()69a86a9b5eSEugene Zhulenev void addNumRefCountedObjects() {
70a86a9b5eSEugene Zhulenev numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
71a86a9b5eSEugene Zhulenev }
dropNumRefCountedObjects()72a86a9b5eSEugene Zhulenev void dropNumRefCountedObjects() {
73a86a9b5eSEugene Zhulenev numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
74a86a9b5eSEugene Zhulenev }
75a86a9b5eSEugene Zhulenev
7692db09cdSEugene Zhulenev std::atomic<int64_t> numRefCountedObjects;
77bb0e6213SEugene Zhulenev llvm::ThreadPool threadPool;
78a86a9b5eSEugene Zhulenev };
79a86a9b5eSEugene Zhulenev
80a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
8139957aa4SEugene Zhulenev // A state of the async runtime value (token, value or group).
8239957aa4SEugene Zhulenev // -------------------------------------------------------------------------- //
8339957aa4SEugene Zhulenev
// Thin value wrapper around the lifecycle state of an async runtime object
// (token, value or group). Converts implicitly to and from `StateEnum` so it
// can be used with `std::atomic<State::StateEnum>` storage.
class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  // Const-qualified so that `const State` objects can be converted as well.
  /* implicit */ operator StateEnum() const { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  // True for both terminal states (available or error).
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  // Returns a human readable name of the state for debug output.
  const char *debug() const {
    // No `default:` label on purpose: the compiler can still warn when a new
    // enumerator is added without a matching case.
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
    // `state` holds a value outside of the enumerators. Return a placeholder
    // instead of flowing off the end of a non-void function (undefined
    // behavior).
    return "unknown";
  }

private:
  StateEnum state;
};
11939957aa4SEugene Zhulenev
12039957aa4SEugene Zhulenev // -------------------------------------------------------------------------- //
121a86a9b5eSEugene Zhulenev // A base class for all reference counted objects created by the async runtime.
122a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
123a86a9b5eSEugene Zhulenev
124a86a9b5eSEugene Zhulenev class RefCounted {
125a86a9b5eSEugene Zhulenev public:
RefCounted(AsyncRuntime * runtime,int64_t refCount=1)12692db09cdSEugene Zhulenev RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
127a86a9b5eSEugene Zhulenev : runtime(runtime), refCount(refCount) {
128a86a9b5eSEugene Zhulenev runtime->addNumRefCountedObjects();
129a86a9b5eSEugene Zhulenev }
130a86a9b5eSEugene Zhulenev
~RefCounted()131a86a9b5eSEugene Zhulenev virtual ~RefCounted() {
132a86a9b5eSEugene Zhulenev assert(refCount.load() == 0 && "reference count must be zero");
133a86a9b5eSEugene Zhulenev runtime->dropNumRefCountedObjects();
134a86a9b5eSEugene Zhulenev }
135a86a9b5eSEugene Zhulenev
136a86a9b5eSEugene Zhulenev RefCounted(const RefCounted &) = delete;
137a86a9b5eSEugene Zhulenev RefCounted &operator=(const RefCounted &) = delete;
138a86a9b5eSEugene Zhulenev
addRef(int64_t count=1)13992db09cdSEugene Zhulenev void addRef(int64_t count = 1) { refCount.fetch_add(count); }
140a86a9b5eSEugene Zhulenev
dropRef(int64_t count=1)14192db09cdSEugene Zhulenev void dropRef(int64_t count = 1) {
14292db09cdSEugene Zhulenev int64_t previous = refCount.fetch_sub(count);
143a86a9b5eSEugene Zhulenev assert(previous >= count && "reference count should not go below zero");
144a86a9b5eSEugene Zhulenev if (previous == count)
145a86a9b5eSEugene Zhulenev destroy();
146a86a9b5eSEugene Zhulenev }
147a86a9b5eSEugene Zhulenev
148a86a9b5eSEugene Zhulenev protected:
destroy()149a86a9b5eSEugene Zhulenev virtual void destroy() { delete this; }
150a86a9b5eSEugene Zhulenev
151a86a9b5eSEugene Zhulenev private:
152a86a9b5eSEugene Zhulenev AsyncRuntime *runtime;
15392db09cdSEugene Zhulenev std::atomic<int64_t> refCount;
154a86a9b5eSEugene Zhulenev };
155a86a9b5eSEugene Zhulenev
156a86a9b5eSEugene Zhulenev } // namespace
157a86a9b5eSEugene Zhulenev
15811f1027bSEugene Zhulenev // Returns the default per-process instance of an async runtime.
getDefaultAsyncRuntimeInstance()1591fc98642SEugene Zhulenev static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
16011f1027bSEugene Zhulenev static auto runtime = std::make_unique<AsyncRuntime>();
1611fc98642SEugene Zhulenev return runtime;
1621fc98642SEugene Zhulenev }
1631fc98642SEugene Zhulenev
resetDefaultAsyncRuntime()1641fc98642SEugene Zhulenev static void resetDefaultAsyncRuntime() {
1651fc98642SEugene Zhulenev return getDefaultAsyncRuntimeInstance().reset();
1661fc98642SEugene Zhulenev }
1671fc98642SEugene Zhulenev
getDefaultAsyncRuntime()1681fc98642SEugene Zhulenev static AsyncRuntime *getDefaultAsyncRuntime() {
1691fc98642SEugene Zhulenev return getDefaultAsyncRuntimeInstance().get();
17011f1027bSEugene Zhulenev }
17111f1027bSEugene Zhulenev
172621ad468SEugene Zhulenev // Async token provides a mechanism to signal asynchronous operation completion.
173a86a9b5eSEugene Zhulenev struct AsyncToken : public RefCounted {
174a86a9b5eSEugene Zhulenev // AsyncToken created with a reference count of 2 because it will be returned
175a86a9b5eSEugene Zhulenev // to the `async.execute` caller and also will be later on emplaced by the
176a86a9b5eSEugene Zhulenev // asynchronously executed task. If the caller immediately will drop its
177a86a9b5eSEugene Zhulenev // reference we must ensure that the token will be alive until the
178a86a9b5eSEugene Zhulenev // asynchronous operation is completed.
AsyncTokenmlir::runtime::AsyncToken179a2223b09SEugene Zhulenev AsyncToken(AsyncRuntime *runtime)
18039957aa4SEugene Zhulenev : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
181a86a9b5eSEugene Zhulenev
18239957aa4SEugene Zhulenev std::atomic<State::StateEnum> state;
183a2223b09SEugene Zhulenev
184a2223b09SEugene Zhulenev // Pending awaiters are guarded by a mutex.
18536ce915aSLei Zhang std::mutex mu;
18636ce915aSLei Zhang std::condition_variable cv;
18736ce915aSLei Zhang std::vector<std::function<void()>> awaiters;
18836ce915aSLei Zhang };
18936ce915aSLei Zhang
190621ad468SEugene Zhulenev // Async value provides a mechanism to access the result of asynchronous
191621ad468SEugene Zhulenev // operations. It owns the storage that is used to store/load the value of the
192621ad468SEugene Zhulenev // underlying type, and a flag to signal if the value is ready or not.
193621ad468SEugene Zhulenev struct AsyncValue : public RefCounted {
194621ad468SEugene Zhulenev // AsyncValue similar to an AsyncToken created with a reference count of 2.
AsyncValuemlir::runtime::AsyncValue19592db09cdSEugene Zhulenev AsyncValue(AsyncRuntime *runtime, int64_t size)
19639957aa4SEugene Zhulenev : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
19739957aa4SEugene Zhulenev storage(size) {}
198621ad468SEugene Zhulenev
19939957aa4SEugene Zhulenev std::atomic<State::StateEnum> state;
200621ad468SEugene Zhulenev
201621ad468SEugene Zhulenev // Use vector of bytes to store async value payload.
202621ad468SEugene Zhulenev std::vector<int8_t> storage;
203a2223b09SEugene Zhulenev
204a2223b09SEugene Zhulenev // Pending awaiters are guarded by a mutex.
205a2223b09SEugene Zhulenev std::mutex mu;
206a2223b09SEugene Zhulenev std::condition_variable cv;
207a2223b09SEugene Zhulenev std::vector<std::function<void()>> awaiters;
208621ad468SEugene Zhulenev };
209621ad468SEugene Zhulenev
210621ad468SEugene Zhulenev // Async group provides a mechanism to group together multiple async tokens or
211621ad468SEugene Zhulenev // values to await on all of them together (wait for the completion of all
212621ad468SEugene Zhulenev // tokens or values added to the group).
213a86a9b5eSEugene Zhulenev struct AsyncGroup : public RefCounted {
AsyncGroupmlir::runtime::AsyncGroup214d43b2360SEugene Zhulenev AsyncGroup(AsyncRuntime *runtime, int64_t size)
215d43b2360SEugene Zhulenev : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
216a86a9b5eSEugene Zhulenev
217a86a9b5eSEugene Zhulenev std::atomic<int> pendingTokens;
218d8c84d2aSEugene Zhulenev std::atomic<int> numErrors;
219a86a9b5eSEugene Zhulenev std::atomic<int> rank;
220a86a9b5eSEugene Zhulenev
221a2223b09SEugene Zhulenev // Pending awaiters are guarded by a mutex.
222c30ab6c2SEugene Zhulenev std::mutex mu;
223c30ab6c2SEugene Zhulenev std::condition_variable cv;
224c30ab6c2SEugene Zhulenev std::vector<std::function<void()>> awaiters;
225c30ab6c2SEugene Zhulenev };
226c30ab6c2SEugene Zhulenev
227a86a9b5eSEugene Zhulenev // Adds references to reference counted runtime object.
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,int64_t count)22892db09cdSEugene Zhulenev extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
229a86a9b5eSEugene Zhulenev RefCounted *refCounted = static_cast<RefCounted *>(ptr);
230a86a9b5eSEugene Zhulenev refCounted->addRef(count);
231a86a9b5eSEugene Zhulenev }
232a86a9b5eSEugene Zhulenev
233a86a9b5eSEugene Zhulenev // Drops references from reference counted runtime object.
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,int64_t count)23492db09cdSEugene Zhulenev extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
235a86a9b5eSEugene Zhulenev RefCounted *refCounted = static_cast<RefCounted *>(ptr);
236a86a9b5eSEugene Zhulenev refCounted->dropRef(count);
237a86a9b5eSEugene Zhulenev }
238a86a9b5eSEugene Zhulenev
239621ad468SEugene Zhulenev // Creates a new `async.token` in not-ready state.
mlirAsyncRuntimeCreateToken()2406fd9e59eSPaul Lietar extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
2411fc98642SEugene Zhulenev AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
24236ce915aSLei Zhang return token;
24336ce915aSLei Zhang }
24436ce915aSLei Zhang
245621ad468SEugene Zhulenev // Creates a new `async.value` in not-ready state.
mlirAsyncRuntimeCreateValue(int64_t size)24692db09cdSEugene Zhulenev extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
2471fc98642SEugene Zhulenev AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
248621ad468SEugene Zhulenev return value;
249621ad468SEugene Zhulenev }
250621ad468SEugene Zhulenev
251c30ab6c2SEugene Zhulenev // Create a new `async.group` in empty state.
mlirAsyncRuntimeCreateGroup(int64_t size)252d43b2360SEugene Zhulenev extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
253d43b2360SEugene Zhulenev AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
254c30ab6c2SEugene Zhulenev return group;
255c30ab6c2SEugene Zhulenev }
256c30ab6c2SEugene Zhulenev
mlirAsyncRuntimeAddTokenToGroup(AsyncToken * token,AsyncGroup * group)2573d95d1b4SEugene Zhulenev extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
2583d95d1b4SEugene Zhulenev AsyncGroup *group) {
259c30ab6c2SEugene Zhulenev std::unique_lock<std::mutex> lockToken(token->mu);
260c30ab6c2SEugene Zhulenev std::unique_lock<std::mutex> lockGroup(group->mu);
261c30ab6c2SEugene Zhulenev
262a86a9b5eSEugene Zhulenev // Get the rank of the token inside the group before we drop the reference.
263a86a9b5eSEugene Zhulenev int rank = group->rank.fetch_add(1);
264c30ab6c2SEugene Zhulenev
265d8c84d2aSEugene Zhulenev auto onTokenReady = [group, token]() {
266d8c84d2aSEugene Zhulenev // Increment the number of errors in the group.
267d8c84d2aSEugene Zhulenev if (State(token->state).isError())
268d8c84d2aSEugene Zhulenev group->numErrors.fetch_add(1);
269d8c84d2aSEugene Zhulenev
270d43b2360SEugene Zhulenev // If pending tokens go below zero it means that more tokens than the group
271d43b2360SEugene Zhulenev // size were added to this group.
272d43b2360SEugene Zhulenev assert(group->pendingTokens > 0 && "wrong group size");
273d43b2360SEugene Zhulenev
274c30ab6c2SEugene Zhulenev // Run all group awaiters if it was the last token in the group.
275c30ab6c2SEugene Zhulenev if (group->pendingTokens.fetch_sub(1) == 1) {
276c30ab6c2SEugene Zhulenev group->cv.notify_all();
277c30ab6c2SEugene Zhulenev for (auto &awaiter : group->awaiters)
278c30ab6c2SEugene Zhulenev awaiter();
279c30ab6c2SEugene Zhulenev }
280c30ab6c2SEugene Zhulenev };
281c30ab6c2SEugene Zhulenev
28239957aa4SEugene Zhulenev if (State(token->state).isAvailableOrError()) {
2833d95d1b4SEugene Zhulenev // Update group pending tokens immediately and maybe run awaiters.
2843d95d1b4SEugene Zhulenev onTokenReady();
2853d95d1b4SEugene Zhulenev
286a86a9b5eSEugene Zhulenev } else {
2873d95d1b4SEugene Zhulenev // Update group pending tokens when token will become ready. Because this
2883d95d1b4SEugene Zhulenev // will happen asynchronously we must ensure that `group` is alive until
2893d95d1b4SEugene Zhulenev // then, and re-ackquire the lock.
290a86a9b5eSEugene Zhulenev group->addRef();
2913d95d1b4SEugene Zhulenev
292e5639b3fSMehdi Amini token->awaiters.emplace_back([group, onTokenReady]() {
2933d95d1b4SEugene Zhulenev // Make sure that `dropRef` does not destroy the mutex owned by the lock.
2943d95d1b4SEugene Zhulenev {
2953d95d1b4SEugene Zhulenev std::unique_lock<std::mutex> lockGroup(group->mu);
2963d95d1b4SEugene Zhulenev onTokenReady();
2973d95d1b4SEugene Zhulenev }
2983d95d1b4SEugene Zhulenev group->dropRef();
2993d95d1b4SEugene Zhulenev });
300a86a9b5eSEugene Zhulenev }
301c30ab6c2SEugene Zhulenev
302a86a9b5eSEugene Zhulenev return rank;
303c30ab6c2SEugene Zhulenev }
304c30ab6c2SEugene Zhulenev
30539957aa4SEugene Zhulenev // Switches `async.token` to available or error state (terminatl state) and runs
30639957aa4SEugene Zhulenev // all awaiters.
setTokenState(AsyncToken * token,State state)30739957aa4SEugene Zhulenev static void setTokenState(AsyncToken *token, State state) {
30839957aa4SEugene Zhulenev assert(state.isAvailableOrError() && "must be terminal state");
30939957aa4SEugene Zhulenev assert(State(token->state).isUnavailable() && "token must be unavailable");
31039957aa4SEugene Zhulenev
3113d95d1b4SEugene Zhulenev // Make sure that `dropRef` does not destroy the mutex owned by the lock.
3123d95d1b4SEugene Zhulenev {
31336ce915aSLei Zhang std::unique_lock<std::mutex> lock(token->mu);
31439957aa4SEugene Zhulenev token->state = state;
31536ce915aSLei Zhang token->cv.notify_all();
31636ce915aSLei Zhang for (auto &awaiter : token->awaiters)
31736ce915aSLei Zhang awaiter();
3183d95d1b4SEugene Zhulenev }
319a86a9b5eSEugene Zhulenev
320a86a9b5eSEugene Zhulenev // Async tokens created with a ref count `2` to keep token alive until the
321a86a9b5eSEugene Zhulenev // async task completes. Drop this reference explicitly when token emplaced.
322a86a9b5eSEugene Zhulenev token->dropRef();
32336ce915aSLei Zhang }
32436ce915aSLei Zhang
setValueState(AsyncValue * value,State state)32539957aa4SEugene Zhulenev static void setValueState(AsyncValue *value, State state) {
32639957aa4SEugene Zhulenev assert(state.isAvailableOrError() && "must be terminal state");
32739957aa4SEugene Zhulenev assert(State(value->state).isUnavailable() && "value must be unavailable");
32839957aa4SEugene Zhulenev
329621ad468SEugene Zhulenev // Make sure that `dropRef` does not destroy the mutex owned by the lock.
330621ad468SEugene Zhulenev {
331621ad468SEugene Zhulenev std::unique_lock<std::mutex> lock(value->mu);
33239957aa4SEugene Zhulenev value->state = state;
333621ad468SEugene Zhulenev value->cv.notify_all();
334621ad468SEugene Zhulenev for (auto &awaiter : value->awaiters)
335621ad468SEugene Zhulenev awaiter();
336621ad468SEugene Zhulenev }
337621ad468SEugene Zhulenev
338621ad468SEugene Zhulenev // Async values created with a ref count `2` to keep value alive until the
339621ad468SEugene Zhulenev // async task completes. Drop this reference explicitly when value emplaced.
340621ad468SEugene Zhulenev value->dropRef();
341621ad468SEugene Zhulenev }
342621ad468SEugene Zhulenev
mlirAsyncRuntimeEmplaceToken(AsyncToken * token)34339957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
34439957aa4SEugene Zhulenev setTokenState(token, State::kAvailable);
34539957aa4SEugene Zhulenev }
34639957aa4SEugene Zhulenev
mlirAsyncRuntimeEmplaceValue(AsyncValue * value)34739957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
34839957aa4SEugene Zhulenev setValueState(value, State::kAvailable);
34939957aa4SEugene Zhulenev }
35039957aa4SEugene Zhulenev
mlirAsyncRuntimeSetTokenError(AsyncToken * token)35139957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
35239957aa4SEugene Zhulenev setTokenState(token, State::kError);
35339957aa4SEugene Zhulenev }
35439957aa4SEugene Zhulenev
mlirAsyncRuntimeSetValueError(AsyncValue * value)35539957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
35639957aa4SEugene Zhulenev setValueState(value, State::kError);
35739957aa4SEugene Zhulenev }
35839957aa4SEugene Zhulenev
mlirAsyncRuntimeIsTokenError(AsyncToken * token)35939957aa4SEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
36039957aa4SEugene Zhulenev return State(token->state).isError();
36139957aa4SEugene Zhulenev }
36239957aa4SEugene Zhulenev
mlirAsyncRuntimeIsValueError(AsyncValue * value)36339957aa4SEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
36439957aa4SEugene Zhulenev return State(value->state).isError();
36539957aa4SEugene Zhulenev }
36639957aa4SEugene Zhulenev
mlirAsyncRuntimeIsGroupError(AsyncGroup * group)367d8c84d2aSEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
368d8c84d2aSEugene Zhulenev return group->numErrors.load() > 0;
369d8c84d2aSEugene Zhulenev }
370d8c84d2aSEugene Zhulenev
mlirAsyncRuntimeAwaitToken(AsyncToken * token)3716fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
37236ce915aSLei Zhang std::unique_lock<std::mutex> lock(token->mu);
37339957aa4SEugene Zhulenev if (!State(token->state).isAvailableOrError())
37439957aa4SEugene Zhulenev token->cv.wait(
37539957aa4SEugene Zhulenev lock, [token] { return State(token->state).isAvailableOrError(); });
376c30ab6c2SEugene Zhulenev }
377c30ab6c2SEugene Zhulenev
mlirAsyncRuntimeAwaitValue(AsyncValue * value)378621ad468SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
379621ad468SEugene Zhulenev std::unique_lock<std::mutex> lock(value->mu);
38039957aa4SEugene Zhulenev if (!State(value->state).isAvailableOrError())
38139957aa4SEugene Zhulenev value->cv.wait(
38239957aa4SEugene Zhulenev lock, [value] { return State(value->state).isAvailableOrError(); });
383621ad468SEugene Zhulenev }
384621ad468SEugene Zhulenev
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup * group)3853d95d1b4SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
386c30ab6c2SEugene Zhulenev std::unique_lock<std::mutex> lock(group->mu);
387c30ab6c2SEugene Zhulenev if (group->pendingTokens != 0)
388c30ab6c2SEugene Zhulenev group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
38936ce915aSLei Zhang }
39036ce915aSLei Zhang
391621ad468SEugene Zhulenev // Returns a pointer to the storage owned by the async value.
mlirAsyncRuntimeGetValueStorage(AsyncValue * value)392621ad468SEugene Zhulenev extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
39339957aa4SEugene Zhulenev assert(!State(value->state).isError() && "unexpected error state");
394621ad468SEugene Zhulenev return value->storage.data();
395621ad468SEugene Zhulenev }
396621ad468SEugene Zhulenev
mlirAsyncRuntimeExecute(CoroHandle handle,CoroResume resume)3976fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
398bb0e6213SEugene Zhulenev auto *runtime = getDefaultAsyncRuntime();
399bb0e6213SEugene Zhulenev runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
40036ce915aSLei Zhang }
40136ce915aSLei Zhang
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken * token,CoroHandle handle,CoroResume resume)4026fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
4036fd9e59eSPaul Lietar CoroHandle handle,
40436ce915aSLei Zhang CoroResume resume) {
4053d95d1b4SEugene Zhulenev auto execute = [handle, resume]() { (*resume)(handle); };
406f63f28edSEugene Zhulenev std::unique_lock<std::mutex> lock(token->mu);
40739957aa4SEugene Zhulenev if (State(token->state).isAvailableOrError()) {
408f63f28edSEugene Zhulenev lock.unlock();
4093d95d1b4SEugene Zhulenev execute();
410a2223b09SEugene Zhulenev } else {
411e5639b3fSMehdi Amini token->awaiters.emplace_back([execute]() { execute(); });
41236ce915aSLei Zhang }
413a2223b09SEugene Zhulenev }
41436ce915aSLei Zhang
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue * value,CoroHandle handle,CoroResume resume)415621ad468SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
416621ad468SEugene Zhulenev CoroHandle handle,
417621ad468SEugene Zhulenev CoroResume resume) {
418621ad468SEugene Zhulenev auto execute = [handle, resume]() { (*resume)(handle); };
419f63f28edSEugene Zhulenev std::unique_lock<std::mutex> lock(value->mu);
42039957aa4SEugene Zhulenev if (State(value->state).isAvailableOrError()) {
421f63f28edSEugene Zhulenev lock.unlock();
422621ad468SEugene Zhulenev execute();
423a2223b09SEugene Zhulenev } else {
424e5639b3fSMehdi Amini value->awaiters.emplace_back([execute]() { execute(); });
425621ad468SEugene Zhulenev }
426a2223b09SEugene Zhulenev }
427621ad468SEugene Zhulenev
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup * group,CoroHandle handle,CoroResume resume)4283d95d1b4SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
4293d95d1b4SEugene Zhulenev CoroHandle handle,
430c30ab6c2SEugene Zhulenev CoroResume resume) {
4313d95d1b4SEugene Zhulenev auto execute = [handle, resume]() { (*resume)(handle); };
432f63f28edSEugene Zhulenev std::unique_lock<std::mutex> lock(group->mu);
433a2223b09SEugene Zhulenev if (group->pendingTokens == 0) {
434f63f28edSEugene Zhulenev lock.unlock();
4353d95d1b4SEugene Zhulenev execute();
436a2223b09SEugene Zhulenev } else {
437e5639b3fSMehdi Amini group->awaiters.emplace_back([execute]() { execute(); });
438c30ab6c2SEugene Zhulenev }
439a2223b09SEugene Zhulenev }
440c30ab6c2SEugene Zhulenev
mlirAsyncRuntimGetNumWorkerThreads()441*149311b4Sbakhtiyar extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
442*149311b4Sbakhtiyar return getDefaultAsyncRuntime()->getThreadPool().getThreadCount();
443*149311b4Sbakhtiyar }
444*149311b4Sbakhtiyar
44536ce915aSLei Zhang //===----------------------------------------------------------------------===//
44636ce915aSLei Zhang // Small async runtime support library for testing.
44736ce915aSLei Zhang //===----------------------------------------------------------------------===//
44836ce915aSLei Zhang
// Prints the id of the calling thread to stdout and flushes the stream.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // Looked up once per thread and cached in a thread-local.
  static thread_local std::thread::id cachedId = std::this_thread::get_id();
  std::cout << "Current thread id: " << cachedId << std::endl;
}
45336ce915aSLei Zhang
4541fc98642SEugene Zhulenev //===----------------------------------------------------------------------===//
4551fc98642SEugene Zhulenev // MLIR Runner (JitRunner) dynamic library integration.
4561fc98642SEugene Zhulenev //===----------------------------------------------------------------------===//
4571fc98642SEugene Zhulenev
4581fc98642SEugene Zhulenev // Export symbols for the MLIR runner integration. All other symbols are hidden.
459dd2dac2fSMatthew Parkinson #ifdef _WIN32
460dd2dac2fSMatthew Parkinson #define API __declspec(dllexport)
461dd2dac2fSMatthew Parkinson #else
4621fc98642SEugene Zhulenev #define API __attribute__((visibility("default")))
463dd2dac2fSMatthew Parkinson #endif
4641fc98642SEugene Zhulenev
465dd2dac2fSMatthew Parkinson // Visual Studio had a bug that fails to compile nested generic lambdas
466dd2dac2fSMatthew Parkinson // inside an `extern "C"` function.
467dd2dac2fSMatthew Parkinson // https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
468dd2dac2fSMatthew Parkinson // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
469dd2dac2fSMatthew Parkinson // a work around for older versions of Visual Studio.
47002b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
471dd2dac2fSMatthew Parkinson extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);
472dd2dac2fSMatthew Parkinson
47302b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_runner_init(llvm::StringMap<void * > & exportSymbols)474dd2dac2fSMatthew Parkinson void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
4751fc98642SEugene Zhulenev auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
4761fc98642SEugene Zhulenev assert(exportSymbols.count(name) == 0 && "symbol already exists");
4771fc98642SEugene Zhulenev exportSymbols[name] = reinterpret_cast<void *>(ptr);
4781fc98642SEugene Zhulenev };
4791fc98642SEugene Zhulenev
4801fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAddRef",
4811fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAddRef);
4821fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeDropRef",
4831fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeDropRef);
4841fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeExecute",
4851fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeExecute);
4861fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeGetValueStorage",
4871fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
4881fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeCreateToken",
4891fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeCreateToken);
4901fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeCreateValue",
4911fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeCreateValue);
4921fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeEmplaceToken",
4931fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
4941fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeEmplaceValue",
4951fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
49639957aa4SEugene Zhulenev exportSymbol("mlirAsyncRuntimeSetTokenError",
49739957aa4SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeSetTokenError);
49839957aa4SEugene Zhulenev exportSymbol("mlirAsyncRuntimeSetValueError",
49939957aa4SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeSetValueError);
50039957aa4SEugene Zhulenev exportSymbol("mlirAsyncRuntimeIsTokenError",
50139957aa4SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeIsTokenError);
50239957aa4SEugene Zhulenev exportSymbol("mlirAsyncRuntimeIsValueError",
50339957aa4SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeIsValueError);
504d8c84d2aSEugene Zhulenev exportSymbol("mlirAsyncRuntimeIsGroupError",
505d8c84d2aSEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeIsGroupError);
5061fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitToken",
5071fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitToken);
5081fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitValue",
5091fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitValue);
5101fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
5111fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
5121fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
5131fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
5141fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeCreateGroup",
5151fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeCreateGroup);
5161fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
5171fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
5181fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
5191fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
5201fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
5211fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
522*149311b4Sbakhtiyar exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
523*149311b4Sbakhtiyar &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
5241fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
5251fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
5261fc98642SEugene Zhulenev }
5271fc98642SEugene Zhulenev
52802b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_runner_destroy()5291fc98642SEugene Zhulenev extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
5301fc98642SEugene Zhulenev
531dd2dac2fSMatthew Parkinson } // namespace runtime
532dd2dac2fSMatthew Parkinson } // namespace mlir
53378b3bce2SEugene Zhulenev
53436ce915aSLei Zhang #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
535