1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15 
16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
17 
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
26 
27 #include "llvm/Support/ThreadPool.h"
28 
29 //===----------------------------------------------------------------------===//
30 // Async runtime API.
31 //===----------------------------------------------------------------------===//
32 
namespace {

// Forward declare class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations and Async runtime API is built
// on top of the default runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  // Every reference counted object created through this runtime must already
  // be destroyed when the runtime itself is destroyed (asserted in debug
  // builds only).
  ~AsyncRuntime() {
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  // Returns the number of live reference counted objects owned by this
  // runtime. The counter uses relaxed atomics: for debugging purposes only.
  int32_t getNumRefCountedObjects() const {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int32_t> numRefCountedObjects;
};

// Returns the default per-process instance of an async runtime. The instance
// is a function-local static: created on first call and destroyed during
// static destruction at program exit.
AsyncRuntime *getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime.get();
}

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  // Registers the new object with `runtime` and starts it with `refCount`
  // references. Marked `explicit` because it is callable with one argument and
  // must not serve as an implicit conversion from `AsyncRuntime *`.
  explicit RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  // Adds `count` references to the object.
  void addRef(int32_t count = 1) { refCount.fetch_add(count); }

  // Drops `count` references; the thread that drops the last reference
  // destroys the object. Dropping more references than are held is a
  // programming error (asserted in debug builds).
  void dropRef(int32_t count = 1) {
    int32_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  // Invoked when the reference count reaches zero. The default implementation
  // does `delete this`, so objects must be allocated with `new`.
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int32_t> refCount;
};

} // namespace
114 
115 struct AsyncToken : public RefCounted {
116   // AsyncToken created with a reference count of 2 because it will be returned
117   // to the `async.execute` caller and also will be later on emplaced by the
118   // asynchronously executed task. If the caller immediately will drop its
119   // reference we must ensure that the token will be alive until the
120   // asynchronous operation is completed.
121   AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
122 
123   // Internal state below guarded by a mutex.
124   std::mutex mu;
125   std::condition_variable cv;
126 
127   bool ready = false;
128   std::vector<std::function<void()>> awaiters;
129 };
130 
131 struct AsyncGroup : public RefCounted {
132   AsyncGroup(AsyncRuntime *runtime)
133       : RefCounted(runtime), pendingTokens(0), rank(0) {}
134 
135   std::atomic<int> pendingTokens;
136   std::atomic<int> rank;
137 
138   // Internal state below guarded by a mutex.
139   std::mutex mu;
140   std::condition_variable cv;
141 
142   std::vector<std::function<void()>> awaiters;
143 };
144 
145 // Adds references to reference counted runtime object.
146 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
147 mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
148   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
149   refCounted->addRef(count);
150 }
151 
152 // Drops references from reference counted runtime object.
153 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
154 mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
155   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
156   refCounted->dropRef(count);
157 }
158 
159 // Create a new `async.token` in not-ready state.
160 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
161   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
162   return token;
163 }
164 
165 // Create a new `async.group` in empty state.
166 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
167   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
168   return group;
169 }
170 
171 extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
172 mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
173   std::unique_lock<std::mutex> lockToken(token->mu);
174   std::unique_lock<std::mutex> lockGroup(group->mu);
175 
176   // Get the rank of the token inside the group before we drop the reference.
177   int rank = group->rank.fetch_add(1);
178   group->pendingTokens.fetch_add(1);
179 
180   auto onTokenReady = [group, token](bool dropRef) {
181     // Run all group awaiters if it was the last token in the group.
182     if (group->pendingTokens.fetch_sub(1) == 1) {
183       group->cv.notify_all();
184       for (auto &awaiter : group->awaiters)
185         awaiter();
186     }
187 
188     // We no longer need the token or the group, drop references on them.
189     if (dropRef) {
190       group->dropRef();
191       token->dropRef();
192     }
193   };
194 
195   if (token->ready) {
196     onTokenReady(false);
197   } else {
198     group->addRef();
199     token->addRef();
200     token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
201   }
202 
203   return rank;
204 }
205 
206 // Switches `async.token` to ready state and runs all awaiters.
207 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
208   std::unique_lock<std::mutex> lock(token->mu);
209   token->ready = true;
210   token->cv.notify_all();
211   for (auto &awaiter : token->awaiters)
212     awaiter();
213 
214   // Async tokens created with a ref count `2` to keep token alive until the
215   // async task completes. Drop this reference explicitly when token emplaced.
216   token->dropRef();
217 }
218 
219 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
220   std::unique_lock<std::mutex> lock(token->mu);
221   if (!token->ready)
222     token->cv.wait(lock, [token] { return token->ready; });
223 }
224 
225 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
226 mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
227   std::unique_lock<std::mutex> lock(group->mu);
228   if (group->pendingTokens != 0)
229     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
230 }
231 
232 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
233 #if LLVM_ENABLE_THREADS
234   static llvm::ThreadPool *threadPool = new llvm::ThreadPool();
235   threadPool->async([handle, resume]() { (*resume)(handle); });
236 #else
237   (*resume)(handle);
238 #endif
239 }
240 
241 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
242                                                      CoroHandle handle,
243                                                      CoroResume resume) {
244   std::unique_lock<std::mutex> lock(token->mu);
245 
246   auto execute = [handle, resume, token](bool dropRef) {
247     if (dropRef)
248       token->dropRef();
249     mlirAsyncRuntimeExecute(handle, resume);
250   };
251 
252   if (token->ready) {
253     execute(false);
254   } else {
255     token->addRef();
256     token->awaiters.push_back([execute]() { execute(true); });
257   }
258 }
259 
260 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
261 mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
262                                           CoroResume resume) {
263   std::unique_lock<std::mutex> lock(group->mu);
264 
265   auto execute = [handle, resume, group](bool dropRef) {
266     if (dropRef)
267       group->dropRef();
268     mlirAsyncRuntimeExecute(handle, resume);
269   };
270 
271   if (group->pendingTokens == 0) {
272     execute(false);
273   } else {
274     group->addRef();
275     group->awaiters.push_back([execute]() { execute(true); });
276   }
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // Small async runtime support library for testing.
281 //===----------------------------------------------------------------------===//
282 
// Test helper: prints the id of the calling thread to stdout.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // Cached per thread on the first call made by that thread.
  static thread_local const std::thread::id currentThread =
      std::this_thread::get_id();
  std::cout << "Current thread id: " << currentThread << "\n";
}
287 
288 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
289