1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15 
16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
17 
18 #include <atomic>
19 #include <cassert>
20 #include <condition_variable>
21 #include <functional>
22 #include <iostream>
23 #include <mutex>
24 #include <thread>
25 #include <vector>
26 
27 #include "llvm/Support/ThreadPool.h"
28 
29 //===----------------------------------------------------------------------===//
30 // Async runtime API.
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 
35 // Forward declare class defined below.
36 class RefCounted;
37 
38 // -------------------------------------------------------------------------- //
39 // AsyncRuntime orchestrates all async operations and Async runtime API is built
40 // on top of the default runtime instance.
41 // -------------------------------------------------------------------------- //
42 
43 class AsyncRuntime {
44 public:
45   AsyncRuntime() : numRefCountedObjects(0) {}
46 
47   ~AsyncRuntime() {
48     threadPool.wait(); // wait for the completion of all async tasks
49     assert(getNumRefCountedObjects() == 0 &&
50            "all ref counted objects must be destroyed");
51   }
52 
53   int32_t getNumRefCountedObjects() {
54     return numRefCountedObjects.load(std::memory_order_relaxed);
55   }
56 
57   llvm::ThreadPool &getThreadPool() { return threadPool; }
58 
59 private:
60   friend class RefCounted;
61 
62   // Count the total number of reference counted objects in this instance
63   // of an AsyncRuntime. For debugging purposes only.
64   void addNumRefCountedObjects() {
65     numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
66   }
67   void dropNumRefCountedObjects() {
68     numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
69   }
70 
71   std::atomic<int32_t> numRefCountedObjects;
72 
73   llvm::ThreadPool threadPool;
74 };
75 
76 // Returns the default per-process instance of an async runtime.
77 AsyncRuntime *getDefaultAsyncRuntimeInstance() {
78   static auto runtime = std::make_unique<AsyncRuntime>();
79   return runtime.get();
80 }
81 
82 // -------------------------------------------------------------------------- //
83 // A base class for all reference counted objects created by the async runtime.
84 // -------------------------------------------------------------------------- //
85 
86 class RefCounted {
87 public:
88   RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
89       : runtime(runtime), refCount(refCount) {
90     runtime->addNumRefCountedObjects();
91   }
92 
93   virtual ~RefCounted() {
94     assert(refCount.load() == 0 && "reference count must be zero");
95     runtime->dropNumRefCountedObjects();
96   }
97 
98   RefCounted(const RefCounted &) = delete;
99   RefCounted &operator=(const RefCounted &) = delete;
100 
101   void addRef(int32_t count = 1) { refCount.fetch_add(count); }
102 
103   void dropRef(int32_t count = 1) {
104     int32_t previous = refCount.fetch_sub(count);
105     assert(previous >= count && "reference count should not go below zero");
106     if (previous == count)
107       destroy();
108   }
109 
110 protected:
111   virtual void destroy() { delete this; }
112 
113 private:
114   AsyncRuntime *runtime;
115   std::atomic<int32_t> refCount;
116 };
117 
118 } // namespace
119 
120 struct AsyncToken : public RefCounted {
121   // AsyncToken created with a reference count of 2 because it will be returned
122   // to the `async.execute` caller and also will be later on emplaced by the
123   // asynchronously executed task. If the caller immediately will drop its
124   // reference we must ensure that the token will be alive until the
125   // asynchronous operation is completed.
126   AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
127 
128   // Internal state below guarded by a mutex.
129   std::mutex mu;
130   std::condition_variable cv;
131 
132   bool ready = false;
133   std::vector<std::function<void()>> awaiters;
134 };
135 
136 struct AsyncGroup : public RefCounted {
137   AsyncGroup(AsyncRuntime *runtime)
138       : RefCounted(runtime), pendingTokens(0), rank(0) {}
139 
140   std::atomic<int> pendingTokens;
141   std::atomic<int> rank;
142 
143   // Internal state below guarded by a mutex.
144   std::mutex mu;
145   std::condition_variable cv;
146 
147   std::vector<std::function<void()>> awaiters;
148 };
149 
150 // Adds references to reference counted runtime object.
151 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
152   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
153   refCounted->addRef(count);
154 }
155 
156 // Drops references from reference counted runtime object.
157 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
158   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
159   refCounted->dropRef(count);
160 }
161 
162 // Create a new `async.token` in not-ready state.
163 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
164   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
165   return token;
166 }
167 
168 // Create a new `async.group` in empty state.
169 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
170   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
171   return group;
172 }
173 
174 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
175                                                    AsyncGroup *group) {
176   std::unique_lock<std::mutex> lockToken(token->mu);
177   std::unique_lock<std::mutex> lockGroup(group->mu);
178 
179   // Get the rank of the token inside the group before we drop the reference.
180   int rank = group->rank.fetch_add(1);
181   group->pendingTokens.fetch_add(1);
182 
183   auto onTokenReady = [group]() {
184     // Run all group awaiters if it was the last token in the group.
185     if (group->pendingTokens.fetch_sub(1) == 1) {
186       group->cv.notify_all();
187       for (auto &awaiter : group->awaiters)
188         awaiter();
189     }
190   };
191 
192   if (token->ready) {
193     // Update group pending tokens immediately and maybe run awaiters.
194     onTokenReady();
195 
196   } else {
197     // Update group pending tokens when token will become ready. Because this
198     // will happen asynchronously we must ensure that `group` is alive until
199     // then, and re-ackquire the lock.
200     group->addRef();
201 
202     token->awaiters.push_back([group, onTokenReady]() {
203       // Make sure that `dropRef` does not destroy the mutex owned by the lock.
204       {
205         std::unique_lock<std::mutex> lockGroup(group->mu);
206         onTokenReady();
207       }
208       group->dropRef();
209     });
210   }
211 
212   return rank;
213 }
214 
215 // Switches `async.token` to ready state and runs all awaiters.
216 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
217   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
218   {
219     std::unique_lock<std::mutex> lock(token->mu);
220     token->ready = true;
221     token->cv.notify_all();
222     for (auto &awaiter : token->awaiters)
223       awaiter();
224   }
225 
226   // Async tokens created with a ref count `2` to keep token alive until the
227   // async task completes. Drop this reference explicitly when token emplaced.
228   token->dropRef();
229 }
230 
231 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
232   std::unique_lock<std::mutex> lock(token->mu);
233   if (!token->ready)
234     token->cv.wait(lock, [token] { return token->ready; });
235 }
236 
237 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
238   std::unique_lock<std::mutex> lock(group->mu);
239   if (group->pendingTokens != 0)
240     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
241 }
242 
243 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
244   auto *runtime = getDefaultAsyncRuntimeInstance();
245   runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
246 }
247 
248 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
249                                                      CoroHandle handle,
250                                                      CoroResume resume) {
251   std::unique_lock<std::mutex> lock(token->mu);
252   auto execute = [handle, resume]() { (*resume)(handle); };
253   if (token->ready)
254     execute();
255   else
256     token->awaiters.push_back([execute]() { execute(); });
257 }
258 
259 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
260                                                           CoroHandle handle,
261                                                           CoroResume resume) {
262   std::unique_lock<std::mutex> lock(group->mu);
263   auto execute = [handle, resume]() { (*resume)(handle); };
264   if (group->pendingTokens == 0)
265     execute();
266   else
267     group->awaiters.push_back([execute]() { execute(); });
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // Small async runtime support library for testing.
272 //===----------------------------------------------------------------------===//
273 
// Prints the id of the calling thread to stdout, followed by a newline and a
// flush of the stream.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  const std::thread::id currentId = std::this_thread::get_id();
  std::cout << "Current thread id: " << currentId << std::endl;
}
278 
279 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
280