1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15 
16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
17 
18 #include <atomic>
19 #include <cassert>
20 #include <condition_variable>
21 #include <functional>
22 #include <iostream>
23 #include <mutex>
24 #include <thread>
25 #include <vector>
26 
27 #include "llvm/ADT/StringMap.h"
28 #include "llvm/Support/ThreadPool.h"
29 
30 using namespace mlir::runtime;
31 
32 //===----------------------------------------------------------------------===//
33 // Async runtime API.
34 //===----------------------------------------------------------------------===//
35 
36 namespace mlir {
37 namespace runtime {
38 namespace {
39 
40 // Forward declare class defined below.
41 class RefCounted;
42 
43 // -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations, and the Async runtime API is
// built on top of the default runtime instance.
46 // -------------------------------------------------------------------------- //
47 
48 class AsyncRuntime {
49 public:
50   AsyncRuntime() : numRefCountedObjects(0) {}
51 
52   ~AsyncRuntime() {
53     threadPool.wait(); // wait for the completion of all async tasks
54     assert(getNumRefCountedObjects() == 0 &&
55            "all ref counted objects must be destroyed");
56   }
57 
58   int32_t getNumRefCountedObjects() {
59     return numRefCountedObjects.load(std::memory_order_relaxed);
60   }
61 
62   llvm::ThreadPool &getThreadPool() { return threadPool; }
63 
64 private:
65   friend class RefCounted;
66 
67   // Count the total number of reference counted objects in this instance
68   // of an AsyncRuntime. For debugging purposes only.
69   void addNumRefCountedObjects() {
70     numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
71   }
72   void dropNumRefCountedObjects() {
73     numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
74   }
75 
76   std::atomic<int32_t> numRefCountedObjects;
77   llvm::ThreadPool threadPool;
78 };
79 
80 // -------------------------------------------------------------------------- //
81 // A base class for all reference counted objects created by the async runtime.
82 // -------------------------------------------------------------------------- //
83 
84 class RefCounted {
85 public:
86   RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
87       : runtime(runtime), refCount(refCount) {
88     runtime->addNumRefCountedObjects();
89   }
90 
91   virtual ~RefCounted() {
92     assert(refCount.load() == 0 && "reference count must be zero");
93     runtime->dropNumRefCountedObjects();
94   }
95 
96   RefCounted(const RefCounted &) = delete;
97   RefCounted &operator=(const RefCounted &) = delete;
98 
99   void addRef(int32_t count = 1) { refCount.fetch_add(count); }
100 
101   void dropRef(int32_t count = 1) {
102     int32_t previous = refCount.fetch_sub(count);
103     assert(previous >= count && "reference count should not go below zero");
104     if (previous == count)
105       destroy();
106   }
107 
108 protected:
109   virtual void destroy() { delete this; }
110 
111 private:
112   AsyncRuntime *runtime;
113   std::atomic<int32_t> refCount;
114 };
115 
116 } // namespace
117 
118 // Returns the default per-process instance of an async runtime.
119 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
120   static auto runtime = std::make_unique<AsyncRuntime>();
121   return runtime;
122 }
123 
124 static void resetDefaultAsyncRuntime() {
125   return getDefaultAsyncRuntimeInstance().reset();
126 }
127 
128 static AsyncRuntime *getDefaultAsyncRuntime() {
129   return getDefaultAsyncRuntimeInstance().get();
130 }
131 
132 // Async token provides a mechanism to signal asynchronous operation completion.
133 struct AsyncToken : public RefCounted {
134   // AsyncToken created with a reference count of 2 because it will be returned
135   // to the `async.execute` caller and also will be later on emplaced by the
136   // asynchronously executed task. If the caller immediately will drop its
137   // reference we must ensure that the token will be alive until the
138   // asynchronous operation is completed.
139   AsyncToken(AsyncRuntime *runtime)
140       : RefCounted(runtime, /*count=*/2), ready(false) {}
141 
142   std::atomic<bool> ready;
143 
144   // Pending awaiters are guarded by a mutex.
145   std::mutex mu;
146   std::condition_variable cv;
147   std::vector<std::function<void()>> awaiters;
148 };
149 
150 // Async value provides a mechanism to access the result of asynchronous
151 // operations. It owns the storage that is used to store/load the value of the
152 // underlying type, and a flag to signal if the value is ready or not.
153 struct AsyncValue : public RefCounted {
154   // AsyncValue similar to an AsyncToken created with a reference count of 2.
155   AsyncValue(AsyncRuntime *runtime, int32_t size)
156       : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {}
157 
158   std::atomic<bool> ready;
159 
160   // Use vector of bytes to store async value payload.
161   std::vector<int8_t> storage;
162 
163   // Pending awaiters are guarded by a mutex.
164   std::mutex mu;
165   std::condition_variable cv;
166   std::vector<std::function<void()>> awaiters;
167 };
168 
169 // Async group provides a mechanism to group together multiple async tokens or
170 // values to await on all of them together (wait for the completion of all
171 // tokens or values added to the group).
172 struct AsyncGroup : public RefCounted {
173   AsyncGroup(AsyncRuntime *runtime)
174       : RefCounted(runtime), pendingTokens(0), rank(0) {}
175 
176   std::atomic<int> pendingTokens;
177   std::atomic<int> rank;
178 
179   // Pending awaiters are guarded by a mutex.
180   std::mutex mu;
181   std::condition_variable cv;
182   std::vector<std::function<void()>> awaiters;
183 };
184 
185 
186 // Adds references to reference counted runtime object.
187 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
188   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
189   refCounted->addRef(count);
190 }
191 
192 // Drops references from reference counted runtime object.
193 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
194   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
195   refCounted->dropRef(count);
196 }
197 
198 // Creates a new `async.token` in not-ready state.
199 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
200   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
201   return token;
202 }
203 
204 // Creates a new `async.value` in not-ready state.
205 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
206   AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
207   return value;
208 }
209 
210 // Create a new `async.group` in empty state.
211 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
212   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
213   return group;
214 }
215 
216 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
217                                                    AsyncGroup *group) {
218   std::unique_lock<std::mutex> lockToken(token->mu);
219   std::unique_lock<std::mutex> lockGroup(group->mu);
220 
221   // Get the rank of the token inside the group before we drop the reference.
222   int rank = group->rank.fetch_add(1);
223   group->pendingTokens.fetch_add(1);
224 
225   auto onTokenReady = [group]() {
226     // Run all group awaiters if it was the last token in the group.
227     if (group->pendingTokens.fetch_sub(1) == 1) {
228       group->cv.notify_all();
229       for (auto &awaiter : group->awaiters)
230         awaiter();
231     }
232   };
233 
234   if (token->ready) {
235     // Update group pending tokens immediately and maybe run awaiters.
236     onTokenReady();
237 
238   } else {
239     // Update group pending tokens when token will become ready. Because this
240     // will happen asynchronously we must ensure that `group` is alive until
241     // then, and re-ackquire the lock.
242     group->addRef();
243 
244     token->awaiters.push_back([group, onTokenReady]() {
245       // Make sure that `dropRef` does not destroy the mutex owned by the lock.
246       {
247         std::unique_lock<std::mutex> lockGroup(group->mu);
248         onTokenReady();
249       }
250       group->dropRef();
251     });
252   }
253 
254   return rank;
255 }
256 
257 // Switches `async.token` to ready state and runs all awaiters.
258 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
259   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
260   {
261     std::unique_lock<std::mutex> lock(token->mu);
262     token->ready = true;
263     token->cv.notify_all();
264     for (auto &awaiter : token->awaiters)
265       awaiter();
266   }
267 
268   // Async tokens created with a ref count `2` to keep token alive until the
269   // async task completes. Drop this reference explicitly when token emplaced.
270   token->dropRef();
271 }
272 
273 // Switches `async.value` to ready state and runs all awaiters.
274 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
275   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
276   {
277     std::unique_lock<std::mutex> lock(value->mu);
278     value->ready = true;
279     value->cv.notify_all();
280     for (auto &awaiter : value->awaiters)
281       awaiter();
282   }
283 
284   // Async values created with a ref count `2` to keep value alive until the
285   // async task completes. Drop this reference explicitly when value emplaced.
286   value->dropRef();
287 }
288 
289 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
290   std::unique_lock<std::mutex> lock(token->mu);
291   if (!token->ready)
292     token->cv.wait(lock, [token] { return token->ready.load(); });
293 }
294 
295 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
296   std::unique_lock<std::mutex> lock(value->mu);
297   if (!value->ready)
298     value->cv.wait(lock, [value] { return value->ready.load(); });
299 }
300 
301 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
302   std::unique_lock<std::mutex> lock(group->mu);
303   if (group->pendingTokens != 0)
304     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
305 }
306 
307 // Returns a pointer to the storage owned by the async value.
308 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
309   return value->storage.data();
310 }
311 
312 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
313   auto *runtime = getDefaultAsyncRuntime();
314   runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
315 }
316 
317 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
318                                                      CoroHandle handle,
319                                                      CoroResume resume) {
320   auto execute = [handle, resume]() { (*resume)(handle); };
321   std::unique_lock<std::mutex> lock(token->mu);
322   if (token->ready) {
323     lock.unlock();
324     execute();
325   } else {
326     token->awaiters.push_back([execute]() { execute(); });
327   }
328 }
329 
330 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
331                                                      CoroHandle handle,
332                                                      CoroResume resume) {
333   auto execute = [handle, resume]() { (*resume)(handle); };
334   std::unique_lock<std::mutex> lock(value->mu);
335   if (value->ready) {
336     lock.unlock();
337     execute();
338   } else {
339     value->awaiters.push_back([execute]() { execute(); });
340   }
341 }
342 
343 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
344                                                           CoroHandle handle,
345                                                           CoroResume resume) {
346   auto execute = [handle, resume]() { (*resume)(handle); };
347   std::unique_lock<std::mutex> lock(group->mu);
348   if (group->pendingTokens == 0) {
349     lock.unlock();
350     execute();
351   } else {
352     group->awaiters.push_back([execute]() { execute(); });
353   }
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // Small async runtime support library for testing.
358 //===----------------------------------------------------------------------===//
359 
// Prints the id of the calling thread to the standard output, followed by a
// newline and a flush of the stream.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  std::cout << "Current thread id: " << std::this_thread::get_id()
            << std::endl;
}
364 
365 //===----------------------------------------------------------------------===//
366 // MLIR Runner (JitRunner) dynamic library integration.
367 //===----------------------------------------------------------------------===//
368 
// Export symbols for the MLIR runner integration. All other symbols are hidden.
// `API` expands to the attribute that makes a symbol visible outside of the
// library: `__declspec(dllexport)` when `_WIN32` is defined, and the
// `visibility("default")` attribute otherwise.
#ifdef _WIN32
#define API __declspec(dllexport)
#else
#define API __attribute__((visibility("default")))
#endif
375 
376 // Visual Studio had a bug that fails to compile nested generic lambdas
377 // inside an `extern "C"` function.
378 //   https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
379 // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
380 // a work around for older versions of Visual Studio.
381 extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);
382 
383 void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
384   auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
385     assert(exportSymbols.count(name) == 0 && "symbol already exists");
386     exportSymbols[name] = reinterpret_cast<void *>(ptr);
387   };
388 
389   exportSymbol("mlirAsyncRuntimeAddRef",
390                &mlir::runtime::mlirAsyncRuntimeAddRef);
391   exportSymbol("mlirAsyncRuntimeDropRef",
392                &mlir::runtime::mlirAsyncRuntimeDropRef);
393   exportSymbol("mlirAsyncRuntimeExecute",
394                &mlir::runtime::mlirAsyncRuntimeExecute);
395   exportSymbol("mlirAsyncRuntimeGetValueStorage",
396                &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
397   exportSymbol("mlirAsyncRuntimeCreateToken",
398                &mlir::runtime::mlirAsyncRuntimeCreateToken);
399   exportSymbol("mlirAsyncRuntimeCreateValue",
400                &mlir::runtime::mlirAsyncRuntimeCreateValue);
401   exportSymbol("mlirAsyncRuntimeEmplaceToken",
402                &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
403   exportSymbol("mlirAsyncRuntimeEmplaceValue",
404                &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
405   exportSymbol("mlirAsyncRuntimeAwaitToken",
406                &mlir::runtime::mlirAsyncRuntimeAwaitToken);
407   exportSymbol("mlirAsyncRuntimeAwaitValue",
408                &mlir::runtime::mlirAsyncRuntimeAwaitValue);
409   exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
410                &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
411   exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
412                &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
413   exportSymbol("mlirAsyncRuntimeCreateGroup",
414                &mlir::runtime::mlirAsyncRuntimeCreateGroup);
415   exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
416                &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
417   exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
418                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
419   exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
420                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
421   exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
422                &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
423 }
424 
425 extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
426 
427 } // namespace runtime
428 } // namespace mlir
429 
430 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
431