1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15 
16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
17 
18 #include <atomic>
19 #include <cassert>
20 #include <condition_variable>
21 #include <functional>
22 #include <iostream>
23 #include <mutex>
24 #include <thread>
25 #include <vector>
26 
27 #include "llvm/ADT/StringMap.h"
28 
29 using namespace mlir::runtime;
30 
31 //===----------------------------------------------------------------------===//
32 // Async runtime API.
33 //===----------------------------------------------------------------------===//
34 
35 namespace mlir {
36 namespace runtime {
37 namespace {
38 
39 // Forward declare class defined below.
40 class RefCounted;
41 
42 // -------------------------------------------------------------------------- //
43 // AsyncRuntime orchestrates all async operations and Async runtime API is built
44 // on top of the default runtime instance.
45 // -------------------------------------------------------------------------- //
46 
class AsyncRuntime {
public:
  AsyncRuntime() = default;

  // Every ref counted object created by this runtime must be released before
  // the runtime itself goes away; checked in debug builds only.
  ~AsyncRuntime() {
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  // Returns the number of currently live ref counted objects created by this
  // runtime. Read with relaxed ordering: intended for debugging, not for
  // synchronization.
  int32_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

private:
  // RefCounted bumps the live-object counter from its constructor and
  // destructor.
  friend class RefCounted;

  // Bookkeeping hooks for RefCounted (debugging aid only).
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  // Live ref counted objects owned by this runtime instance.
  std::atomic<int32_t> numRefCountedObjects{0};
};
74 
75 // -------------------------------------------------------------------------- //
76 // A base class for all reference counted objects created by the async runtime.
77 // -------------------------------------------------------------------------- //
78 
79 class RefCounted {
80 public:
81   RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
82       : runtime(runtime), refCount(refCount) {
83     runtime->addNumRefCountedObjects();
84   }
85 
86   virtual ~RefCounted() {
87     assert(refCount.load() == 0 && "reference count must be zero");
88     runtime->dropNumRefCountedObjects();
89   }
90 
91   RefCounted(const RefCounted &) = delete;
92   RefCounted &operator=(const RefCounted &) = delete;
93 
94   void addRef(int32_t count = 1) { refCount.fetch_add(count); }
95 
96   void dropRef(int32_t count = 1) {
97     int32_t previous = refCount.fetch_sub(count);
98     assert(previous >= count && "reference count should not go below zero");
99     if (previous == count)
100       destroy();
101   }
102 
103 protected:
104   virtual void destroy() { delete this; }
105 
106 private:
107   AsyncRuntime *runtime;
108   std::atomic<int32_t> refCount;
109 };
110 
111 } // namespace
112 
113 // Returns the default per-process instance of an async runtime.
114 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
115   static auto runtime = std::make_unique<AsyncRuntime>();
116   return runtime;
117 }
118 
119 static void resetDefaultAsyncRuntime() {
120   return getDefaultAsyncRuntimeInstance().reset();
121 }
122 
123 static AsyncRuntime *getDefaultAsyncRuntime() {
124   return getDefaultAsyncRuntimeInstance().get();
125 }
126 
127 // Async token provides a mechanism to signal asynchronous operation completion.
128 struct AsyncToken : public RefCounted {
129   // AsyncToken created with a reference count of 2 because it will be returned
130   // to the `async.execute` caller and also will be later on emplaced by the
131   // asynchronously executed task. If the caller immediately will drop its
132   // reference we must ensure that the token will be alive until the
133   // asynchronous operation is completed.
134   AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
135 
136   // Internal state below guarded by a mutex.
137   std::mutex mu;
138   std::condition_variable cv;
139 
140   bool ready = false;
141   std::vector<std::function<void()>> awaiters;
142 };
143 
144 // Async value provides a mechanism to access the result of asynchronous
145 // operations. It owns the storage that is used to store/load the value of the
146 // underlying type, and a flag to signal if the value is ready or not.
147 struct AsyncValue : public RefCounted {
148   // AsyncValue similar to an AsyncToken created with a reference count of 2.
149   AsyncValue(AsyncRuntime *runtime, int32_t size)
150       : RefCounted(runtime, /*count=*/2), storage(size) {}
151 
152   // Internal state below guarded by a mutex.
153   std::mutex mu;
154   std::condition_variable cv;
155 
156   bool ready = false;
157   std::vector<std::function<void()>> awaiters;
158 
159   // Use vector of bytes to store async value payload.
160   std::vector<int8_t> storage;
161 };
162 
163 // Async group provides a mechanism to group together multiple async tokens or
164 // values to await on all of them together (wait for the completion of all
165 // tokens or values added to the group).
166 struct AsyncGroup : public RefCounted {
167   AsyncGroup(AsyncRuntime *runtime)
168       : RefCounted(runtime), pendingTokens(0), rank(0) {}
169 
170   std::atomic<int> pendingTokens;
171   std::atomic<int> rank;
172 
173   // Internal state below guarded by a mutex.
174   std::mutex mu;
175   std::condition_variable cv;
176 
177   std::vector<std::function<void()>> awaiters;
178 };
179 
180 } // namespace runtime
181 } // namespace mlir
182 
183 // Adds references to reference counted runtime object.
184 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
185   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
186   refCounted->addRef(count);
187 }
188 
189 // Drops references from reference counted runtime object.
190 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
191   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
192   refCounted->dropRef(count);
193 }
194 
195 // Creates a new `async.token` in not-ready state.
196 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
197   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
198   return token;
199 }
200 
201 // Creates a new `async.value` in not-ready state.
202 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
203   AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
204   return value;
205 }
206 
207 // Create a new `async.group` in empty state.
208 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
209   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
210   return group;
211 }
212 
213 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
214                                                    AsyncGroup *group) {
215   std::unique_lock<std::mutex> lockToken(token->mu);
216   std::unique_lock<std::mutex> lockGroup(group->mu);
217 
218   // Get the rank of the token inside the group before we drop the reference.
219   int rank = group->rank.fetch_add(1);
220   group->pendingTokens.fetch_add(1);
221 
222   auto onTokenReady = [group]() {
223     // Run all group awaiters if it was the last token in the group.
224     if (group->pendingTokens.fetch_sub(1) == 1) {
225       group->cv.notify_all();
226       for (auto &awaiter : group->awaiters)
227         awaiter();
228     }
229   };
230 
231   if (token->ready) {
232     // Update group pending tokens immediately and maybe run awaiters.
233     onTokenReady();
234 
235   } else {
236     // Update group pending tokens when token will become ready. Because this
237     // will happen asynchronously we must ensure that `group` is alive until
238     // then, and re-ackquire the lock.
239     group->addRef();
240 
241     token->awaiters.push_back([group, onTokenReady]() {
242       // Make sure that `dropRef` does not destroy the mutex owned by the lock.
243       {
244         std::unique_lock<std::mutex> lockGroup(group->mu);
245         onTokenReady();
246       }
247       group->dropRef();
248     });
249   }
250 
251   return rank;
252 }
253 
254 // Switches `async.token` to ready state and runs all awaiters.
255 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
256   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
257   {
258     std::unique_lock<std::mutex> lock(token->mu);
259     token->ready = true;
260     token->cv.notify_all();
261     for (auto &awaiter : token->awaiters)
262       awaiter();
263   }
264 
265   // Async tokens created with a ref count `2` to keep token alive until the
266   // async task completes. Drop this reference explicitly when token emplaced.
267   token->dropRef();
268 }
269 
270 // Switches `async.value` to ready state and runs all awaiters.
271 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
272   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
273   {
274     std::unique_lock<std::mutex> lock(value->mu);
275     value->ready = true;
276     value->cv.notify_all();
277     for (auto &awaiter : value->awaiters)
278       awaiter();
279   }
280 
281   // Async values created with a ref count `2` to keep value alive until the
282   // async task completes. Drop this reference explicitly when value emplaced.
283   value->dropRef();
284 }
285 
286 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
287   std::unique_lock<std::mutex> lock(token->mu);
288   if (!token->ready)
289     token->cv.wait(lock, [token] { return token->ready; });
290 }
291 
292 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
293   std::unique_lock<std::mutex> lock(value->mu);
294   if (!value->ready)
295     value->cv.wait(lock, [value] { return value->ready; });
296 }
297 
298 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
299   std::unique_lock<std::mutex> lock(group->mu);
300   if (group->pendingTokens != 0)
301     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
302 }
303 
304 // Returns a pointer to the storage owned by the async value.
305 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
306   return value->storage.data();
307 }
308 
309 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
310   (*resume)(handle);
311 }
312 
313 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
314                                                      CoroHandle handle,
315                                                      CoroResume resume) {
316   std::unique_lock<std::mutex> lock(token->mu);
317   auto execute = [handle, resume]() { (*resume)(handle); };
318   if (token->ready)
319     execute();
320   else
321     token->awaiters.push_back([execute]() { execute(); });
322 }
323 
324 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
325                                                      CoroHandle handle,
326                                                      CoroResume resume) {
327   std::unique_lock<std::mutex> lock(value->mu);
328   auto execute = [handle, resume]() { (*resume)(handle); };
329   if (value->ready)
330     execute();
331   else
332     value->awaiters.push_back([execute]() { execute(); });
333 }
334 
335 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
336                                                           CoroHandle handle,
337                                                           CoroResume resume) {
338   std::unique_lock<std::mutex> lock(group->mu);
339   auto execute = [handle, resume]() { (*resume)(handle); };
340   if (group->pendingTokens == 0)
341     execute();
342   else
343     group->awaiters.push_back([execute]() { execute(); });
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // Small async runtime support library for testing.
348 //===----------------------------------------------------------------------===//
349 
// Testing helper: prints the calling thread's id to stdout and flushes.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // Looked up once per thread and cached.
  static thread_local const std::thread::id currentId =
      std::this_thread::get_id();
  std::cout << "Current thread id: " << currentId << std::endl;
}
354 
355 //===----------------------------------------------------------------------===//
356 // MLIR Runner (JitRunner) dynamic library integration.
357 //===----------------------------------------------------------------------===//
358 
359 // Export symbols for the MLIR runner integration. All other symbols are hidden.
360 #ifndef _WIN32
361 #define API __attribute__((visibility("default")))
362 
363 extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
364   auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
365     assert(exportSymbols.count(name) == 0 && "symbol already exists");
366     exportSymbols[name] = reinterpret_cast<void *>(ptr);
367   };
368 
369   exportSymbol("mlirAsyncRuntimeAddRef",
370                &mlir::runtime::mlirAsyncRuntimeAddRef);
371   exportSymbol("mlirAsyncRuntimeDropRef",
372                &mlir::runtime::mlirAsyncRuntimeDropRef);
373   exportSymbol("mlirAsyncRuntimeExecute",
374                &mlir::runtime::mlirAsyncRuntimeExecute);
375   exportSymbol("mlirAsyncRuntimeGetValueStorage",
376                &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
377   exportSymbol("mlirAsyncRuntimeCreateToken",
378                &mlir::runtime::mlirAsyncRuntimeCreateToken);
379   exportSymbol("mlirAsyncRuntimeCreateValue",
380                &mlir::runtime::mlirAsyncRuntimeCreateValue);
381   exportSymbol("mlirAsyncRuntimeEmplaceToken",
382                &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
383   exportSymbol("mlirAsyncRuntimeEmplaceValue",
384                &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
385   exportSymbol("mlirAsyncRuntimeAwaitToken",
386                &mlir::runtime::mlirAsyncRuntimeAwaitToken);
387   exportSymbol("mlirAsyncRuntimeAwaitValue",
388                &mlir::runtime::mlirAsyncRuntimeAwaitValue);
389   exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
390                &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
391   exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
392                &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
393   exportSymbol("mlirAsyncRuntimeCreateGroup",
394                &mlir::runtime::mlirAsyncRuntimeCreateGroup);
395   exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
396                &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
397   exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
398                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
399   exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
400                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
401   exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
402                &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
403 }
404 
405 extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
406 
407 #endif // _WIN32
408 
409 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
410