1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15 
16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
17 
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
26 
27 //===----------------------------------------------------------------------===//
28 // Async runtime API.
29 //===----------------------------------------------------------------------===//
30 
namespace {

// Forward declare class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations and Async runtime API is built
// on top of the default runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  // In debug builds, verifies that no reference counted object registered with
  // this runtime outlives it.
  ~AsyncRuntime() {
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  // Returns the number of live reference counted objects registered with this
  // runtime. The load is relaxed, so the value is only a snapshot suitable for
  // debugging and assertions, not for synchronization.
  int32_t getNumRefCountedObjects() const {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int32_t> numRefCountedObjects;
};

// Returns the default per-process instance of an async runtime. The instance
// is created lazily on first call (function-local static, thread-safe
// initialization) and destroyed during static destruction.
AsyncRuntime *getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime.get();
}

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  // Registers the object with `runtime` and starts it with `refCount`
  // references. `explicit` prevents accidental implicit conversions from an
  // `AsyncRuntime *` to a reference counted object.
  explicit RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  // Unregisters the object from its runtime. Objects must only be destroyed
  // through `dropRef` once the count has reached zero.
  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  // Adds `count` references.
  void addRef(int32_t count = 1) { refCount.fetch_add(count); }

  // Drops `count` references and destroys the object when the count reaches
  // exactly zero.
  void dropRef(int32_t count = 1) {
    int32_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  // Releases the object; derived classes may override the deallocation.
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int32_t> refCount;
};

} // namespace
112 
113 struct AsyncToken : public RefCounted {
114   // AsyncToken created with a reference count of 2 because it will be returned
115   // to the `async.execute` caller and also will be later on emplaced by the
116   // asynchronously executed task. If the caller immediately will drop its
117   // reference we must ensure that the token will be alive until the
118   // asynchronous operation is completed.
119   AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
120 
121   // Internal state below guarded by a mutex.
122   std::mutex mu;
123   std::condition_variable cv;
124 
125   bool ready = false;
126   std::vector<std::function<void()>> awaiters;
127 };
128 
129 struct AsyncGroup : public RefCounted {
130   AsyncGroup(AsyncRuntime *runtime)
131       : RefCounted(runtime), pendingTokens(0), rank(0) {}
132 
133   std::atomic<int> pendingTokens;
134   std::atomic<int> rank;
135 
136   // Internal state below guarded by a mutex.
137   std::mutex mu;
138   std::condition_variable cv;
139 
140   std::vector<std::function<void()>> awaiters;
141 };
142 
143 // Adds references to reference counted runtime object.
144 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
145 mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
146   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
147   refCounted->addRef(count);
148 }
149 
150 // Drops references from reference counted runtime object.
151 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
152 mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
153   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
154   refCounted->dropRef(count);
155 }
156 
157 // Create a new `async.token` in not-ready state.
158 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
159   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
160   return token;
161 }
162 
163 // Create a new `async.group` in empty state.
164 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
165   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
166   return group;
167 }
168 
169 extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
170 mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
171   std::unique_lock<std::mutex> lockToken(token->mu);
172   std::unique_lock<std::mutex> lockGroup(group->mu);
173 
174   // Get the rank of the token inside the group before we drop the reference.
175   int rank = group->rank.fetch_add(1);
176   group->pendingTokens.fetch_add(1);
177 
178   auto onTokenReady = [group, token](bool dropRef) {
179     // Run all group awaiters if it was the last token in the group.
180     if (group->pendingTokens.fetch_sub(1) == 1) {
181       group->cv.notify_all();
182       for (auto &awaiter : group->awaiters)
183         awaiter();
184     }
185 
186     // We no longer need the token or the group, drop references on them.
187     if (dropRef) {
188       group->dropRef();
189       token->dropRef();
190     }
191   };
192 
193   if (token->ready) {
194     onTokenReady(false);
195   } else {
196     group->addRef();
197     token->addRef();
198     token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
199   }
200 
201   return rank;
202 }
203 
204 // Switches `async.token` to ready state and runs all awaiters.
205 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
206   std::unique_lock<std::mutex> lock(token->mu);
207   token->ready = true;
208   token->cv.notify_all();
209   for (auto &awaiter : token->awaiters)
210     awaiter();
211 
212   // Async tokens created with a ref count `2` to keep token alive until the
213   // async task completes. Drop this reference explicitly when token emplaced.
214   token->dropRef();
215 }
216 
217 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
218   std::unique_lock<std::mutex> lock(token->mu);
219   if (!token->ready)
220     token->cv.wait(lock, [token] { return token->ready; });
221 }
222 
223 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
224 mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
225   std::unique_lock<std::mutex> lock(group->mu);
226   if (group->pendingTokens != 0)
227     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
228 }
229 
230 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
231 #if LLVM_ENABLE_THREADS
232   std::thread thread([handle, resume]() { (*resume)(handle); });
233   thread.detach();
234 #else
235   (*resume)(handle);
236 #endif
237 }
238 
239 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
240                                                      CoroHandle handle,
241                                                      CoroResume resume) {
242   std::unique_lock<std::mutex> lock(token->mu);
243 
244   auto execute = [handle, resume, token](bool dropRef) {
245     if (dropRef)
246       token->dropRef();
247     mlirAsyncRuntimeExecute(handle, resume);
248   };
249 
250   if (token->ready) {
251     execute(false);
252   } else {
253     token->addRef();
254     token->awaiters.push_back([execute]() { execute(true); });
255   }
256 }
257 
258 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
259 mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
260                                           CoroResume resume) {
261   std::unique_lock<std::mutex> lock(group->mu);
262 
263   auto execute = [handle, resume, group](bool dropRef) {
264     if (dropRef)
265       group->dropRef();
266     mlirAsyncRuntimeExecute(handle, resume);
267   };
268 
269   if (group->pendingTokens == 0) {
270     execute(false);
271   } else {
272     group->addRef();
273     group->awaiters.push_back([execute]() { execute(true); });
274   }
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // Small async runtime support library for testing.
279 //===----------------------------------------------------------------------===//
280 
// Prints the id of the calling thread to stdout. The id is captured once per
// thread in a thread_local and reused on subsequent calls.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local const std::thread::id currentId =
      std::this_thread::get_id();
  std::cout << "Current thread id: " << currentId << "\n";
}
285 
286 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
287