1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15 
16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
17 
18 #include <atomic>
19 #include <cassert>
20 #include <condition_variable>
21 #include <functional>
22 #include <iostream>
23 #include <mutex>
24 #include <thread>
25 #include <vector>
26 
27 using namespace mlir::runtime;
28 
29 //===----------------------------------------------------------------------===//
30 // Async runtime API.
31 //===----------------------------------------------------------------------===//
32 
33 namespace mlir {
34 namespace runtime {
35 namespace {
36 
37 // Forward declare class defined below.
38 class RefCounted;
39 
40 // -------------------------------------------------------------------------- //
41 // AsyncRuntime orchestrates all async operations and Async runtime API is built
42 // on top of the default runtime instance.
43 // -------------------------------------------------------------------------- //
44 
class AsyncRuntime {
public:
  AsyncRuntime() = default;

  // In assert-enabled builds, checks that every reference counted object
  // registered with this runtime has already been destroyed.
  ~AsyncRuntime() {
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  // Returns the number of reference counted objects currently registered with
  // this runtime instance.
  int32_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

private:
  // Only `RefCounted` is allowed to update the live object counter.
  friend class RefCounted;

  // Bookkeeping of live reference counted objects owned by this runtime
  // instance. For debugging purposes only, hence the relaxed memory ordering.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  // Number of live reference counted objects; starts at zero.
  std::atomic<int32_t> numRefCountedObjects{0};
};
72 
73 // -------------------------------------------------------------------------- //
74 // A base class for all reference counted objects created by the async runtime.
75 // -------------------------------------------------------------------------- //
76 
77 class RefCounted {
78 public:
79   RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
80       : runtime(runtime), refCount(refCount) {
81     runtime->addNumRefCountedObjects();
82   }
83 
84   virtual ~RefCounted() {
85     assert(refCount.load() == 0 && "reference count must be zero");
86     runtime->dropNumRefCountedObjects();
87   }
88 
89   RefCounted(const RefCounted &) = delete;
90   RefCounted &operator=(const RefCounted &) = delete;
91 
92   void addRef(int32_t count = 1) { refCount.fetch_add(count); }
93 
94   void dropRef(int32_t count = 1) {
95     int32_t previous = refCount.fetch_sub(count);
96     assert(previous >= count && "reference count should not go below zero");
97     if (previous == count)
98       destroy();
99   }
100 
101 protected:
102   virtual void destroy() { delete this; }
103 
104 private:
105   AsyncRuntime *runtime;
106   std::atomic<int32_t> refCount;
107 };
108 
109 } // namespace
110 
111 // Returns the default per-process instance of an async runtime.
112 static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
113   static auto runtime = std::make_unique<AsyncRuntime>();
114   return runtime.get();
115 }
116 
117 // Async token provides a mechanism to signal asynchronous operation completion.
118 struct AsyncToken : public RefCounted {
119   // AsyncToken created with a reference count of 2 because it will be returned
120   // to the `async.execute` caller and also will be later on emplaced by the
121   // asynchronously executed task. If the caller immediately will drop its
122   // reference we must ensure that the token will be alive until the
123   // asynchronous operation is completed.
124   AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
125 
126   // Internal state below guarded by a mutex.
127   std::mutex mu;
128   std::condition_variable cv;
129 
130   bool ready = false;
131   std::vector<std::function<void()>> awaiters;
132 };
133 
134 // Async value provides a mechanism to access the result of asynchronous
135 // operations. It owns the storage that is used to store/load the value of the
136 // underlying type, and a flag to signal if the value is ready or not.
137 struct AsyncValue : public RefCounted {
138   // AsyncValue similar to an AsyncToken created with a reference count of 2.
139   AsyncValue(AsyncRuntime *runtime, int32_t size)
140       : RefCounted(runtime, /*count=*/2), storage(size) {}
141 
142   // Internal state below guarded by a mutex.
143   std::mutex mu;
144   std::condition_variable cv;
145 
146   bool ready = false;
147   std::vector<std::function<void()>> awaiters;
148 
149   // Use vector of bytes to store async value payload.
150   std::vector<int8_t> storage;
151 };
152 
153 // Async group provides a mechanism to group together multiple async tokens or
154 // values to await on all of them together (wait for the completion of all
155 // tokens or values added to the group).
156 struct AsyncGroup : public RefCounted {
157   AsyncGroup(AsyncRuntime *runtime)
158       : RefCounted(runtime), pendingTokens(0), rank(0) {}
159 
160   std::atomic<int> pendingTokens;
161   std::atomic<int> rank;
162 
163   // Internal state below guarded by a mutex.
164   std::mutex mu;
165   std::condition_variable cv;
166 
167   std::vector<std::function<void()>> awaiters;
168 };
169 
170 } // namespace runtime
171 } // namespace mlir
172 
173 // Adds references to reference counted runtime object.
174 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
175   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
176   refCounted->addRef(count);
177 }
178 
179 // Drops references from reference counted runtime object.
180 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
181   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
182   refCounted->dropRef(count);
183 }
184 
185 // Creates a new `async.token` in not-ready state.
186 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
187   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
188   return token;
189 }
190 
191 // Creates a new `async.value` in not-ready state.
192 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
193   AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
194   return value;
195 }
196 
197 // Create a new `async.group` in empty state.
198 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
199   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
200   return group;
201 }
202 
203 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
204                                                    AsyncGroup *group) {
205   std::unique_lock<std::mutex> lockToken(token->mu);
206   std::unique_lock<std::mutex> lockGroup(group->mu);
207 
208   // Get the rank of the token inside the group before we drop the reference.
209   int rank = group->rank.fetch_add(1);
210   group->pendingTokens.fetch_add(1);
211 
212   auto onTokenReady = [group]() {
213     // Run all group awaiters if it was the last token in the group.
214     if (group->pendingTokens.fetch_sub(1) == 1) {
215       group->cv.notify_all();
216       for (auto &awaiter : group->awaiters)
217         awaiter();
218     }
219   };
220 
221   if (token->ready) {
222     // Update group pending tokens immediately and maybe run awaiters.
223     onTokenReady();
224 
225   } else {
226     // Update group pending tokens when token will become ready. Because this
227     // will happen asynchronously we must ensure that `group` is alive until
228     // then, and re-ackquire the lock.
229     group->addRef();
230 
231     token->awaiters.push_back([group, onTokenReady]() {
232       // Make sure that `dropRef` does not destroy the mutex owned by the lock.
233       {
234         std::unique_lock<std::mutex> lockGroup(group->mu);
235         onTokenReady();
236       }
237       group->dropRef();
238     });
239   }
240 
241   return rank;
242 }
243 
244 // Switches `async.token` to ready state and runs all awaiters.
245 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
246   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
247   {
248     std::unique_lock<std::mutex> lock(token->mu);
249     token->ready = true;
250     token->cv.notify_all();
251     for (auto &awaiter : token->awaiters)
252       awaiter();
253   }
254 
255   // Async tokens created with a ref count `2` to keep token alive until the
256   // async task completes. Drop this reference explicitly when token emplaced.
257   token->dropRef();
258 }
259 
260 // Switches `async.value` to ready state and runs all awaiters.
261 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
262   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
263   {
264     std::unique_lock<std::mutex> lock(value->mu);
265     value->ready = true;
266     value->cv.notify_all();
267     for (auto &awaiter : value->awaiters)
268       awaiter();
269   }
270 
271   // Async values created with a ref count `2` to keep value alive until the
272   // async task completes. Drop this reference explicitly when value emplaced.
273   value->dropRef();
274 }
275 
276 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
277   std::unique_lock<std::mutex> lock(token->mu);
278   if (!token->ready)
279     token->cv.wait(lock, [token] { return token->ready; });
280 }
281 
282 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
283   std::unique_lock<std::mutex> lock(value->mu);
284   if (!value->ready)
285     value->cv.wait(lock, [value] { return value->ready; });
286 }
287 
288 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
289   std::unique_lock<std::mutex> lock(group->mu);
290   if (group->pendingTokens != 0)
291     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
292 }
293 
294 // Returns a pointer to the storage owned by the async value.
295 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
296   return value->storage.data();
297 }
298 
299 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
300   (*resume)(handle);
301 }
302 
303 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
304                                                      CoroHandle handle,
305                                                      CoroResume resume) {
306   std::unique_lock<std::mutex> lock(token->mu);
307   auto execute = [handle, resume]() { (*resume)(handle); };
308   if (token->ready)
309     execute();
310   else
311     token->awaiters.push_back([execute]() { execute(); });
312 }
313 
314 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
315                                                      CoroHandle handle,
316                                                      CoroResume resume) {
317   std::unique_lock<std::mutex> lock(value->mu);
318   auto execute = [handle, resume]() { (*resume)(handle); };
319   if (value->ready)
320     execute();
321   else
322     value->awaiters.push_back([execute]() { execute(); });
323 }
324 
325 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
326                                                           CoroHandle handle,
327                                                           CoroResume resume) {
328   std::unique_lock<std::mutex> lock(group->mu);
329   auto execute = [handle, resume]() { (*resume)(handle); };
330   if (group->pendingTokens == 0)
331     execute();
332   else
333     group->awaiters.push_back([execute]() { execute(); });
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // Small async runtime support library for testing.
338 //===----------------------------------------------------------------------===//
339 
// Prints the id of the calling thread to the standard output, followed by a
// newline, and flushes the stream.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // The id is looked up once per thread and cached for subsequent calls.
  static thread_local std::thread::id cachedId = std::this_thread::get_id();
  std::cout << "Current thread id: " << cachedId << '\n' << std::flush;
}
344 
345 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
346