//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic Async runtime API for supporting Async dialect
// to LLVM dialect lowering.
//
//===----------------------------------------------------------------------===//
1336ce915aSLei Zhang 
1436ce915aSLei Zhang #include "mlir/ExecutionEngine/AsyncRuntime.h"
1536ce915aSLei Zhang 
1636ce915aSLei Zhang #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
1736ce915aSLei Zhang 
18c30ab6c2SEugene Zhulenev #include <atomic>
19a86a9b5eSEugene Zhulenev #include <cassert>
2036ce915aSLei Zhang #include <condition_variable>
2136ce915aSLei Zhang #include <functional>
2236ce915aSLei Zhang #include <iostream>
2336ce915aSLei Zhang #include <mutex>
2436ce915aSLei Zhang #include <thread>
2536ce915aSLei Zhang #include <vector>
2636ce915aSLei Zhang 
271fc98642SEugene Zhulenev #include "llvm/ADT/StringMap.h"
28bb0e6213SEugene Zhulenev #include "llvm/Support/ThreadPool.h"
291fc98642SEugene Zhulenev 
3011f1027bSEugene Zhulenev using namespace mlir::runtime;
3111f1027bSEugene Zhulenev 
//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//
3536ce915aSLei Zhang 
3611f1027bSEugene Zhulenev namespace mlir {
3711f1027bSEugene Zhulenev namespace runtime {
38a86a9b5eSEugene Zhulenev namespace {
39a86a9b5eSEugene Zhulenev 
40a86a9b5eSEugene Zhulenev // Forward declare class defined below.
41a86a9b5eSEugene Zhulenev class RefCounted;
42a86a9b5eSEugene Zhulenev 
43a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
44a86a9b5eSEugene Zhulenev // AsyncRuntime orchestrates all async operations and Async runtime API is built
45a86a9b5eSEugene Zhulenev // on top of the default runtime instance.
46a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
47a86a9b5eSEugene Zhulenev 
48a86a9b5eSEugene Zhulenev class AsyncRuntime {
49a86a9b5eSEugene Zhulenev public:
AsyncRuntime()50a86a9b5eSEugene Zhulenev   AsyncRuntime() : numRefCountedObjects(0) {}
51a86a9b5eSEugene Zhulenev 
~AsyncRuntime()52a86a9b5eSEugene Zhulenev   ~AsyncRuntime() {
53bb0e6213SEugene Zhulenev     threadPool.wait(); // wait for the completion of all async tasks
54a86a9b5eSEugene Zhulenev     assert(getNumRefCountedObjects() == 0 &&
55a86a9b5eSEugene Zhulenev            "all ref counted objects must be destroyed");
56a86a9b5eSEugene Zhulenev   }
57a86a9b5eSEugene Zhulenev 
getNumRefCountedObjects()5892db09cdSEugene Zhulenev   int64_t getNumRefCountedObjects() {
59a86a9b5eSEugene Zhulenev     return numRefCountedObjects.load(std::memory_order_relaxed);
60a86a9b5eSEugene Zhulenev   }
61a86a9b5eSEugene Zhulenev 
getThreadPool()62bb0e6213SEugene Zhulenev   llvm::ThreadPool &getThreadPool() { return threadPool; }
63bb0e6213SEugene Zhulenev 
64a86a9b5eSEugene Zhulenev private:
65a86a9b5eSEugene Zhulenev   friend class RefCounted;
66a86a9b5eSEugene Zhulenev 
67a86a9b5eSEugene Zhulenev   // Count the total number of reference counted objects in this instance
68a86a9b5eSEugene Zhulenev   // of an AsyncRuntime. For debugging purposes only.
addNumRefCountedObjects()69a86a9b5eSEugene Zhulenev   void addNumRefCountedObjects() {
70a86a9b5eSEugene Zhulenev     numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
71a86a9b5eSEugene Zhulenev   }
dropNumRefCountedObjects()72a86a9b5eSEugene Zhulenev   void dropNumRefCountedObjects() {
73a86a9b5eSEugene Zhulenev     numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
74a86a9b5eSEugene Zhulenev   }
75a86a9b5eSEugene Zhulenev 
7692db09cdSEugene Zhulenev   std::atomic<int64_t> numRefCountedObjects;
77bb0e6213SEugene Zhulenev   llvm::ThreadPool threadPool;
78a86a9b5eSEugene Zhulenev };
79a86a9b5eSEugene Zhulenev 
// -------------------------------------------------------------------------- //
// A state of the async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

// Thin wrapper around `StateEnum` that adds convenience predicates and a
// printable name. It converts implicitly from and to the underlying enum.
class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  // Const-qualified so that `const State` objects can be converted as well.
  /* implicit */ operator StateEnum() const { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  // Returns a human readable name of the state. Values outside of the
  // declared enumerators (representable because the enum has a fixed
  // underlying type) yield "unknown" rather than flowing off the end of the
  // function, which would be undefined behavior.
  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
    return "unknown";
  }

private:
  StateEnum state;
};
11939957aa4SEugene Zhulenev 
12039957aa4SEugene Zhulenev // -------------------------------------------------------------------------- //
121a86a9b5eSEugene Zhulenev // A base class for all reference counted objects created by the async runtime.
122a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
123a86a9b5eSEugene Zhulenev 
124a86a9b5eSEugene Zhulenev class RefCounted {
125a86a9b5eSEugene Zhulenev public:
RefCounted(AsyncRuntime * runtime,int64_t refCount=1)12692db09cdSEugene Zhulenev   RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
127a86a9b5eSEugene Zhulenev       : runtime(runtime), refCount(refCount) {
128a86a9b5eSEugene Zhulenev     runtime->addNumRefCountedObjects();
129a86a9b5eSEugene Zhulenev   }
130a86a9b5eSEugene Zhulenev 
~RefCounted()131a86a9b5eSEugene Zhulenev   virtual ~RefCounted() {
132a86a9b5eSEugene Zhulenev     assert(refCount.load() == 0 && "reference count must be zero");
133a86a9b5eSEugene Zhulenev     runtime->dropNumRefCountedObjects();
134a86a9b5eSEugene Zhulenev   }
135a86a9b5eSEugene Zhulenev 
136a86a9b5eSEugene Zhulenev   RefCounted(const RefCounted &) = delete;
137a86a9b5eSEugene Zhulenev   RefCounted &operator=(const RefCounted &) = delete;
138a86a9b5eSEugene Zhulenev 
addRef(int64_t count=1)13992db09cdSEugene Zhulenev   void addRef(int64_t count = 1) { refCount.fetch_add(count); }
140a86a9b5eSEugene Zhulenev 
dropRef(int64_t count=1)14192db09cdSEugene Zhulenev   void dropRef(int64_t count = 1) {
14292db09cdSEugene Zhulenev     int64_t previous = refCount.fetch_sub(count);
143a86a9b5eSEugene Zhulenev     assert(previous >= count && "reference count should not go below zero");
144a86a9b5eSEugene Zhulenev     if (previous == count)
145a86a9b5eSEugene Zhulenev       destroy();
146a86a9b5eSEugene Zhulenev   }
147a86a9b5eSEugene Zhulenev 
148a86a9b5eSEugene Zhulenev protected:
destroy()149a86a9b5eSEugene Zhulenev   virtual void destroy() { delete this; }
150a86a9b5eSEugene Zhulenev 
151a86a9b5eSEugene Zhulenev private:
152a86a9b5eSEugene Zhulenev   AsyncRuntime *runtime;
15392db09cdSEugene Zhulenev   std::atomic<int64_t> refCount;
154a86a9b5eSEugene Zhulenev };
155a86a9b5eSEugene Zhulenev 
156a86a9b5eSEugene Zhulenev } // namespace
157a86a9b5eSEugene Zhulenev 
15811f1027bSEugene Zhulenev // Returns the default per-process instance of an async runtime.
getDefaultAsyncRuntimeInstance()1591fc98642SEugene Zhulenev static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
16011f1027bSEugene Zhulenev   static auto runtime = std::make_unique<AsyncRuntime>();
1611fc98642SEugene Zhulenev   return runtime;
1621fc98642SEugene Zhulenev }
1631fc98642SEugene Zhulenev 
resetDefaultAsyncRuntime()1641fc98642SEugene Zhulenev static void resetDefaultAsyncRuntime() {
1651fc98642SEugene Zhulenev   return getDefaultAsyncRuntimeInstance().reset();
1661fc98642SEugene Zhulenev }
1671fc98642SEugene Zhulenev 
getDefaultAsyncRuntime()1681fc98642SEugene Zhulenev static AsyncRuntime *getDefaultAsyncRuntime() {
1691fc98642SEugene Zhulenev   return getDefaultAsyncRuntimeInstance().get();
17011f1027bSEugene Zhulenev }
17111f1027bSEugene Zhulenev 
172621ad468SEugene Zhulenev // Async token provides a mechanism to signal asynchronous operation completion.
173a86a9b5eSEugene Zhulenev struct AsyncToken : public RefCounted {
174a86a9b5eSEugene Zhulenev   // AsyncToken created with a reference count of 2 because it will be returned
175a86a9b5eSEugene Zhulenev   // to the `async.execute` caller and also will be later on emplaced by the
176a86a9b5eSEugene Zhulenev   // asynchronously executed task. If the caller immediately will drop its
177a86a9b5eSEugene Zhulenev   // reference we must ensure that the token will be alive until the
178a86a9b5eSEugene Zhulenev   // asynchronous operation is completed.
AsyncTokenmlir::runtime::AsyncToken179a2223b09SEugene Zhulenev   AsyncToken(AsyncRuntime *runtime)
18039957aa4SEugene Zhulenev       : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
181a86a9b5eSEugene Zhulenev 
18239957aa4SEugene Zhulenev   std::atomic<State::StateEnum> state;
183a2223b09SEugene Zhulenev 
184a2223b09SEugene Zhulenev   // Pending awaiters are guarded by a mutex.
18536ce915aSLei Zhang   std::mutex mu;
18636ce915aSLei Zhang   std::condition_variable cv;
18736ce915aSLei Zhang   std::vector<std::function<void()>> awaiters;
18836ce915aSLei Zhang };
18936ce915aSLei Zhang 
190621ad468SEugene Zhulenev // Async value provides a mechanism to access the result of asynchronous
191621ad468SEugene Zhulenev // operations. It owns the storage that is used to store/load the value of the
192621ad468SEugene Zhulenev // underlying type, and a flag to signal if the value is ready or not.
193621ad468SEugene Zhulenev struct AsyncValue : public RefCounted {
194621ad468SEugene Zhulenev   // AsyncValue similar to an AsyncToken created with a reference count of 2.
AsyncValuemlir::runtime::AsyncValue19592db09cdSEugene Zhulenev   AsyncValue(AsyncRuntime *runtime, int64_t size)
19639957aa4SEugene Zhulenev       : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
19739957aa4SEugene Zhulenev         storage(size) {}
198621ad468SEugene Zhulenev 
19939957aa4SEugene Zhulenev   std::atomic<State::StateEnum> state;
200621ad468SEugene Zhulenev 
201621ad468SEugene Zhulenev   // Use vector of bytes to store async value payload.
202621ad468SEugene Zhulenev   std::vector<int8_t> storage;
203a2223b09SEugene Zhulenev 
204a2223b09SEugene Zhulenev   // Pending awaiters are guarded by a mutex.
205a2223b09SEugene Zhulenev   std::mutex mu;
206a2223b09SEugene Zhulenev   std::condition_variable cv;
207a2223b09SEugene Zhulenev   std::vector<std::function<void()>> awaiters;
208621ad468SEugene Zhulenev };
209621ad468SEugene Zhulenev 
210621ad468SEugene Zhulenev // Async group provides a mechanism to group together multiple async tokens or
211621ad468SEugene Zhulenev // values to await on all of them together (wait for the completion of all
212621ad468SEugene Zhulenev // tokens or values added to the group).
213a86a9b5eSEugene Zhulenev struct AsyncGroup : public RefCounted {
AsyncGroupmlir::runtime::AsyncGroup214d43b2360SEugene Zhulenev   AsyncGroup(AsyncRuntime *runtime, int64_t size)
215d43b2360SEugene Zhulenev       : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
216a86a9b5eSEugene Zhulenev 
217a86a9b5eSEugene Zhulenev   std::atomic<int> pendingTokens;
218d8c84d2aSEugene Zhulenev   std::atomic<int> numErrors;
219a86a9b5eSEugene Zhulenev   std::atomic<int> rank;
220a86a9b5eSEugene Zhulenev 
221a2223b09SEugene Zhulenev   // Pending awaiters are guarded by a mutex.
222c30ab6c2SEugene Zhulenev   std::mutex mu;
223c30ab6c2SEugene Zhulenev   std::condition_variable cv;
224c30ab6c2SEugene Zhulenev   std::vector<std::function<void()>> awaiters;
225c30ab6c2SEugene Zhulenev };
226c30ab6c2SEugene Zhulenev 
227a86a9b5eSEugene Zhulenev // Adds references to reference counted runtime object.
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,int64_t count)22892db09cdSEugene Zhulenev extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
229a86a9b5eSEugene Zhulenev   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
230a86a9b5eSEugene Zhulenev   refCounted->addRef(count);
231a86a9b5eSEugene Zhulenev }
232a86a9b5eSEugene Zhulenev 
233a86a9b5eSEugene Zhulenev // Drops references from reference counted runtime object.
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,int64_t count)23492db09cdSEugene Zhulenev extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
235a86a9b5eSEugene Zhulenev   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
236a86a9b5eSEugene Zhulenev   refCounted->dropRef(count);
237a86a9b5eSEugene Zhulenev }
238a86a9b5eSEugene Zhulenev 
239621ad468SEugene Zhulenev // Creates a new `async.token` in not-ready state.
mlirAsyncRuntimeCreateToken()2406fd9e59eSPaul Lietar extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
2411fc98642SEugene Zhulenev   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
24236ce915aSLei Zhang   return token;
24336ce915aSLei Zhang }
24436ce915aSLei Zhang 
245621ad468SEugene Zhulenev // Creates a new `async.value` in not-ready state.
mlirAsyncRuntimeCreateValue(int64_t size)24692db09cdSEugene Zhulenev extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
2471fc98642SEugene Zhulenev   AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
248621ad468SEugene Zhulenev   return value;
249621ad468SEugene Zhulenev }
250621ad468SEugene Zhulenev 
251c30ab6c2SEugene Zhulenev // Create a new `async.group` in empty state.
mlirAsyncRuntimeCreateGroup(int64_t size)252d43b2360SEugene Zhulenev extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
253d43b2360SEugene Zhulenev   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
254c30ab6c2SEugene Zhulenev   return group;
255c30ab6c2SEugene Zhulenev }
256c30ab6c2SEugene Zhulenev 
mlirAsyncRuntimeAddTokenToGroup(AsyncToken * token,AsyncGroup * group)2573d95d1b4SEugene Zhulenev extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
2583d95d1b4SEugene Zhulenev                                                    AsyncGroup *group) {
259c30ab6c2SEugene Zhulenev   std::unique_lock<std::mutex> lockToken(token->mu);
260c30ab6c2SEugene Zhulenev   std::unique_lock<std::mutex> lockGroup(group->mu);
261c30ab6c2SEugene Zhulenev 
262a86a9b5eSEugene Zhulenev   // Get the rank of the token inside the group before we drop the reference.
263a86a9b5eSEugene Zhulenev   int rank = group->rank.fetch_add(1);
264c30ab6c2SEugene Zhulenev 
265d8c84d2aSEugene Zhulenev   auto onTokenReady = [group, token]() {
266d8c84d2aSEugene Zhulenev     // Increment the number of errors in the group.
267d8c84d2aSEugene Zhulenev     if (State(token->state).isError())
268d8c84d2aSEugene Zhulenev       group->numErrors.fetch_add(1);
269d8c84d2aSEugene Zhulenev 
270d43b2360SEugene Zhulenev     // If pending tokens go below zero it means that more tokens than the group
271d43b2360SEugene Zhulenev     // size were added to this group.
272d43b2360SEugene Zhulenev     assert(group->pendingTokens > 0 && "wrong group size");
273d43b2360SEugene Zhulenev 
274c30ab6c2SEugene Zhulenev     // Run all group awaiters if it was the last token in the group.
275c30ab6c2SEugene Zhulenev     if (group->pendingTokens.fetch_sub(1) == 1) {
276c30ab6c2SEugene Zhulenev       group->cv.notify_all();
277c30ab6c2SEugene Zhulenev       for (auto &awaiter : group->awaiters)
278c30ab6c2SEugene Zhulenev         awaiter();
279c30ab6c2SEugene Zhulenev     }
280c30ab6c2SEugene Zhulenev   };
281c30ab6c2SEugene Zhulenev 
28239957aa4SEugene Zhulenev   if (State(token->state).isAvailableOrError()) {
2833d95d1b4SEugene Zhulenev     // Update group pending tokens immediately and maybe run awaiters.
2843d95d1b4SEugene Zhulenev     onTokenReady();
2853d95d1b4SEugene Zhulenev 
286a86a9b5eSEugene Zhulenev   } else {
2873d95d1b4SEugene Zhulenev     // Update group pending tokens when token will become ready. Because this
2883d95d1b4SEugene Zhulenev     // will happen asynchronously we must ensure that `group` is alive until
2893d95d1b4SEugene Zhulenev     // then, and re-ackquire the lock.
290a86a9b5eSEugene Zhulenev     group->addRef();
2913d95d1b4SEugene Zhulenev 
292e5639b3fSMehdi Amini     token->awaiters.emplace_back([group, onTokenReady]() {
2933d95d1b4SEugene Zhulenev       // Make sure that `dropRef` does not destroy the mutex owned by the lock.
2943d95d1b4SEugene Zhulenev       {
2953d95d1b4SEugene Zhulenev         std::unique_lock<std::mutex> lockGroup(group->mu);
2963d95d1b4SEugene Zhulenev         onTokenReady();
2973d95d1b4SEugene Zhulenev       }
2983d95d1b4SEugene Zhulenev       group->dropRef();
2993d95d1b4SEugene Zhulenev     });
300a86a9b5eSEugene Zhulenev   }
301c30ab6c2SEugene Zhulenev 
302a86a9b5eSEugene Zhulenev   return rank;
303c30ab6c2SEugene Zhulenev }
304c30ab6c2SEugene Zhulenev 
30539957aa4SEugene Zhulenev // Switches `async.token` to available or error state (terminatl state) and runs
30639957aa4SEugene Zhulenev // all awaiters.
setTokenState(AsyncToken * token,State state)30739957aa4SEugene Zhulenev static void setTokenState(AsyncToken *token, State state) {
30839957aa4SEugene Zhulenev   assert(state.isAvailableOrError() && "must be terminal state");
30939957aa4SEugene Zhulenev   assert(State(token->state).isUnavailable() && "token must be unavailable");
31039957aa4SEugene Zhulenev 
3113d95d1b4SEugene Zhulenev   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
3123d95d1b4SEugene Zhulenev   {
31336ce915aSLei Zhang     std::unique_lock<std::mutex> lock(token->mu);
31439957aa4SEugene Zhulenev     token->state = state;
31536ce915aSLei Zhang     token->cv.notify_all();
31636ce915aSLei Zhang     for (auto &awaiter : token->awaiters)
31736ce915aSLei Zhang       awaiter();
3183d95d1b4SEugene Zhulenev   }
319a86a9b5eSEugene Zhulenev 
320a86a9b5eSEugene Zhulenev   // Async tokens created with a ref count `2` to keep token alive until the
321a86a9b5eSEugene Zhulenev   // async task completes. Drop this reference explicitly when token emplaced.
322a86a9b5eSEugene Zhulenev   token->dropRef();
32336ce915aSLei Zhang }
32436ce915aSLei Zhang 
setValueState(AsyncValue * value,State state)32539957aa4SEugene Zhulenev static void setValueState(AsyncValue *value, State state) {
32639957aa4SEugene Zhulenev   assert(state.isAvailableOrError() && "must be terminal state");
32739957aa4SEugene Zhulenev   assert(State(value->state).isUnavailable() && "value must be unavailable");
32839957aa4SEugene Zhulenev 
329621ad468SEugene Zhulenev   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
330621ad468SEugene Zhulenev   {
331621ad468SEugene Zhulenev     std::unique_lock<std::mutex> lock(value->mu);
33239957aa4SEugene Zhulenev     value->state = state;
333621ad468SEugene Zhulenev     value->cv.notify_all();
334621ad468SEugene Zhulenev     for (auto &awaiter : value->awaiters)
335621ad468SEugene Zhulenev       awaiter();
336621ad468SEugene Zhulenev   }
337621ad468SEugene Zhulenev 
338621ad468SEugene Zhulenev   // Async values created with a ref count `2` to keep value alive until the
339621ad468SEugene Zhulenev   // async task completes. Drop this reference explicitly when value emplaced.
340621ad468SEugene Zhulenev   value->dropRef();
341621ad468SEugene Zhulenev }
342621ad468SEugene Zhulenev 
mlirAsyncRuntimeEmplaceToken(AsyncToken * token)34339957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
34439957aa4SEugene Zhulenev   setTokenState(token, State::kAvailable);
34539957aa4SEugene Zhulenev }
34639957aa4SEugene Zhulenev 
mlirAsyncRuntimeEmplaceValue(AsyncValue * value)34739957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
34839957aa4SEugene Zhulenev   setValueState(value, State::kAvailable);
34939957aa4SEugene Zhulenev }
35039957aa4SEugene Zhulenev 
mlirAsyncRuntimeSetTokenError(AsyncToken * token)35139957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
35239957aa4SEugene Zhulenev   setTokenState(token, State::kError);
35339957aa4SEugene Zhulenev }
35439957aa4SEugene Zhulenev 
mlirAsyncRuntimeSetValueError(AsyncValue * value)35539957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
35639957aa4SEugene Zhulenev   setValueState(value, State::kError);
35739957aa4SEugene Zhulenev }
35839957aa4SEugene Zhulenev 
mlirAsyncRuntimeIsTokenError(AsyncToken * token)35939957aa4SEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
36039957aa4SEugene Zhulenev   return State(token->state).isError();
36139957aa4SEugene Zhulenev }
36239957aa4SEugene Zhulenev 
mlirAsyncRuntimeIsValueError(AsyncValue * value)36339957aa4SEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
36439957aa4SEugene Zhulenev   return State(value->state).isError();
36539957aa4SEugene Zhulenev }
36639957aa4SEugene Zhulenev 
mlirAsyncRuntimeIsGroupError(AsyncGroup * group)367d8c84d2aSEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
368d8c84d2aSEugene Zhulenev   return group->numErrors.load() > 0;
369d8c84d2aSEugene Zhulenev }
370d8c84d2aSEugene Zhulenev 
mlirAsyncRuntimeAwaitToken(AsyncToken * token)3716fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
37236ce915aSLei Zhang   std::unique_lock<std::mutex> lock(token->mu);
37339957aa4SEugene Zhulenev   if (!State(token->state).isAvailableOrError())
37439957aa4SEugene Zhulenev     token->cv.wait(
37539957aa4SEugene Zhulenev         lock, [token] { return State(token->state).isAvailableOrError(); });
376c30ab6c2SEugene Zhulenev }
377c30ab6c2SEugene Zhulenev 
mlirAsyncRuntimeAwaitValue(AsyncValue * value)378621ad468SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
379621ad468SEugene Zhulenev   std::unique_lock<std::mutex> lock(value->mu);
38039957aa4SEugene Zhulenev   if (!State(value->state).isAvailableOrError())
38139957aa4SEugene Zhulenev     value->cv.wait(
38239957aa4SEugene Zhulenev         lock, [value] { return State(value->state).isAvailableOrError(); });
383621ad468SEugene Zhulenev }
384621ad468SEugene Zhulenev 
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup * group)3853d95d1b4SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
386c30ab6c2SEugene Zhulenev   std::unique_lock<std::mutex> lock(group->mu);
387c30ab6c2SEugene Zhulenev   if (group->pendingTokens != 0)
388c30ab6c2SEugene Zhulenev     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
38936ce915aSLei Zhang }
39036ce915aSLei Zhang 
391621ad468SEugene Zhulenev // Returns a pointer to the storage owned by the async value.
mlirAsyncRuntimeGetValueStorage(AsyncValue * value)392621ad468SEugene Zhulenev extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
39339957aa4SEugene Zhulenev   assert(!State(value->state).isError() && "unexpected error state");
394621ad468SEugene Zhulenev   return value->storage.data();
395621ad468SEugene Zhulenev }
396621ad468SEugene Zhulenev 
mlirAsyncRuntimeExecute(CoroHandle handle,CoroResume resume)3976fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
398bb0e6213SEugene Zhulenev   auto *runtime = getDefaultAsyncRuntime();
399bb0e6213SEugene Zhulenev   runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
40036ce915aSLei Zhang }
40136ce915aSLei Zhang 
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken * token,CoroHandle handle,CoroResume resume)4026fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
4036fd9e59eSPaul Lietar                                                      CoroHandle handle,
40436ce915aSLei Zhang                                                      CoroResume resume) {
4053d95d1b4SEugene Zhulenev   auto execute = [handle, resume]() { (*resume)(handle); };
406f63f28edSEugene Zhulenev   std::unique_lock<std::mutex> lock(token->mu);
40739957aa4SEugene Zhulenev   if (State(token->state).isAvailableOrError()) {
408f63f28edSEugene Zhulenev     lock.unlock();
4093d95d1b4SEugene Zhulenev     execute();
410a2223b09SEugene Zhulenev   } else {
411e5639b3fSMehdi Amini     token->awaiters.emplace_back([execute]() { execute(); });
41236ce915aSLei Zhang   }
413a2223b09SEugene Zhulenev }
41436ce915aSLei Zhang 
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue * value,CoroHandle handle,CoroResume resume)415621ad468SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
416621ad468SEugene Zhulenev                                                      CoroHandle handle,
417621ad468SEugene Zhulenev                                                      CoroResume resume) {
418621ad468SEugene Zhulenev   auto execute = [handle, resume]() { (*resume)(handle); };
419f63f28edSEugene Zhulenev   std::unique_lock<std::mutex> lock(value->mu);
42039957aa4SEugene Zhulenev   if (State(value->state).isAvailableOrError()) {
421f63f28edSEugene Zhulenev     lock.unlock();
422621ad468SEugene Zhulenev     execute();
423a2223b09SEugene Zhulenev   } else {
424e5639b3fSMehdi Amini     value->awaiters.emplace_back([execute]() { execute(); });
425621ad468SEugene Zhulenev   }
426a2223b09SEugene Zhulenev }
427621ad468SEugene Zhulenev 
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup * group,CoroHandle handle,CoroResume resume)4283d95d1b4SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
4293d95d1b4SEugene Zhulenev                                                           CoroHandle handle,
430c30ab6c2SEugene Zhulenev                                                           CoroResume resume) {
4313d95d1b4SEugene Zhulenev   auto execute = [handle, resume]() { (*resume)(handle); };
432f63f28edSEugene Zhulenev   std::unique_lock<std::mutex> lock(group->mu);
433a2223b09SEugene Zhulenev   if (group->pendingTokens == 0) {
434f63f28edSEugene Zhulenev     lock.unlock();
4353d95d1b4SEugene Zhulenev     execute();
436a2223b09SEugene Zhulenev   } else {
437e5639b3fSMehdi Amini     group->awaiters.emplace_back([execute]() { execute(); });
438c30ab6c2SEugene Zhulenev   }
439a2223b09SEugene Zhulenev }
440c30ab6c2SEugene Zhulenev 
mlirAsyncRuntimGetNumWorkerThreads()441*149311b4Sbakhtiyar extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
442*149311b4Sbakhtiyar   return getDefaultAsyncRuntime()->getThreadPool().getThreadCount();
443*149311b4Sbakhtiyar }
444*149311b4Sbakhtiyar 
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//
44836ce915aSLei Zhang 
// Prints the id of the calling thread to standard output. The id is looked up
// once per thread and cached in a thread-local variable.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id cachedId = std::this_thread::get_id();
  std::cout << "Current thread id: ";
  std::cout << cachedId << std::endl;
}
45336ce915aSLei Zhang 
//===----------------------------------------------------------------------===//
// MLIR Runner (JitRunner) dynamic library integration.
//===----------------------------------------------------------------------===//
4571fc98642SEugene Zhulenev 
4581fc98642SEugene Zhulenev // Export symbols for the MLIR runner integration. All other symbols are hidden.
459dd2dac2fSMatthew Parkinson #ifdef _WIN32
460dd2dac2fSMatthew Parkinson #define API __declspec(dllexport)
461dd2dac2fSMatthew Parkinson #else
4621fc98642SEugene Zhulenev #define API __attribute__((visibility("default")))
463dd2dac2fSMatthew Parkinson #endif
4641fc98642SEugene Zhulenev 
465dd2dac2fSMatthew Parkinson // Visual Studio had a bug that fails to compile nested generic lambdas
466dd2dac2fSMatthew Parkinson // inside an `extern "C"` function.
467dd2dac2fSMatthew Parkinson //   https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
468dd2dac2fSMatthew Parkinson // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
469dd2dac2fSMatthew Parkinson // a work around for older versions of Visual Studio.
47002b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
471dd2dac2fSMatthew Parkinson extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);
472dd2dac2fSMatthew Parkinson 
47302b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_runner_init(llvm::StringMap<void * > & exportSymbols)474dd2dac2fSMatthew Parkinson void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
4751fc98642SEugene Zhulenev   auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
4761fc98642SEugene Zhulenev     assert(exportSymbols.count(name) == 0 && "symbol already exists");
4771fc98642SEugene Zhulenev     exportSymbols[name] = reinterpret_cast<void *>(ptr);
4781fc98642SEugene Zhulenev   };
4791fc98642SEugene Zhulenev 
4801fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAddRef",
4811fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAddRef);
4821fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeDropRef",
4831fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeDropRef);
4841fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeExecute",
4851fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeExecute);
4861fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeGetValueStorage",
4871fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
4881fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeCreateToken",
4891fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeCreateToken);
4901fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeCreateValue",
4911fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeCreateValue);
4921fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeEmplaceToken",
4931fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
4941fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeEmplaceValue",
4951fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
49639957aa4SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeSetTokenError",
49739957aa4SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeSetTokenError);
49839957aa4SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeSetValueError",
49939957aa4SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeSetValueError);
50039957aa4SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeIsTokenError",
50139957aa4SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeIsTokenError);
50239957aa4SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeIsValueError",
50339957aa4SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeIsValueError);
504d8c84d2aSEugene Zhulenev   exportSymbol("mlirAsyncRuntimeIsGroupError",
505d8c84d2aSEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeIsGroupError);
5061fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitToken",
5071fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitToken);
5081fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitValue",
5091fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitValue);
5101fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
5111fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
5121fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
5131fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
5141fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeCreateGroup",
5151fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeCreateGroup);
5161fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
5171fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
5181fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
5191fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
5201fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
5211fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
522*149311b4Sbakhtiyar   exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
523*149311b4Sbakhtiyar                &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
5241fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
5251fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
5261fc98642SEugene Zhulenev }
5271fc98642SEugene Zhulenev 
52802b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_runner_destroy()5291fc98642SEugene Zhulenev extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
5301fc98642SEugene Zhulenev 
531dd2dac2fSMatthew Parkinson } // namespace runtime
532dd2dac2fSMatthew Parkinson } // namespace mlir
53378b3bce2SEugene Zhulenev 
53436ce915aSLei Zhang #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
535