//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the basic Async runtime API used to support the
// lowering of the Async dialect to the LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward declare a class defined below.
class RefCounted;
// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations, and the Async runtime API is
// built on top of the default runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  ~AsyncRuntime() {
    threadPool.wait(); // wait for the completion of all async tasks
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  int64_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

  llvm::ThreadPool &getThreadPool() { return threadPool; }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int64_t> numRefCountedObjects;
  llvm::ThreadPool threadPool;
};

// -------------------------------------------------------------------------- //
// A state of the async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state cannot
    // transition to any other state.
    kAvailable = 1,
    // The underlying value is available and contains an error. This state
    // cannot transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
  }

private:
  StateEnum state;
};
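
// The only legal transitions are from the unavailable state to one of the two
// terminal states (a summary of the StateEnum comments above):
//
//   kUnavailable --emplace----> kAvailable
//   kUnavailable --set error--> kError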

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  void addRef(int64_t count = 1) { refCount.fetch_add(count); }

  void dropRef(int64_t count = 1) {
    int64_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int64_t> refCount;
};
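
// A sketch of the intended reference counting discipline (illustration only;
// compiled programs are expected to issue the matching mlirAsyncRuntimeAddRef /
// mlirAsyncRuntimeDropRef calls rather than use this C++ API directly):
//
//   RefCounted *obj = ...; // created with refCount == 1
//   obj->addRef();         // a second owner appears   (refCount == 2)
//   obj->dropRef();        // first owner is done       (refCount == 1)
//   obj->dropRef();        // last owner is done -> destroy() deletes obj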

} // namespace

// Returns the default per-process instance of an async runtime.
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime;
}

static void resetDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().reset();
}

static AsyncRuntime *getDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().get();
}

// Async token provides a mechanism to signal asynchronous operation completion.
struct AsyncToken : public RefCounted {
  // An AsyncToken is created with a reference count of 2 because it is
  // returned to the `async.execute` caller and will later be emplaced by the
  // asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the token stays alive until the
  // asynchronous operation is completed.
  AsyncToken(AsyncRuntime *runtime)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}

  std::atomic<State::StateEnum> state;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
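
// A sketch of the token lifecycle implied by the comment above (hypothetical
// host-side code; in practice the calls are emitted by the lowered program):
//
//   AsyncToken *token = mlirAsyncRuntimeCreateToken(); // refCount == 2
//   // ... the producer eventually calls mlirAsyncRuntimeEmplaceToken(token),
//   // which runs all awaiters and drops the producer reference ...
//   mlirAsyncRuntimeAwaitToken(token);  // consumer blocks until emplaced
//   mlirAsyncRuntimeDropRef(token, 1);  // consumer drops the last reference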

// Async value provides a mechanism to access the result of asynchronous
// operations. It owns the storage that is used to store/load the value of the
// underlying type, and a flag to signal if the value is ready or not.
struct AsyncValue : public RefCounted {
  // Like an AsyncToken, an AsyncValue is created with a reference count of 2.
  AsyncValue(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
        storage(size) {}

  std::atomic<State::StateEnum> state;

  // Use a vector of bytes to store the async value payload.
  std::vector<int8_t> storage;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};

// Async group provides a mechanism to group together multiple async tokens or
// values so that they can be awaited on together (i.e. wait for the completion
// of all tokens or values added to the group).
struct AsyncGroup : public RefCounted {
  AsyncGroup(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}

  std::atomic<int> pendingTokens;
  std::atomic<int> numErrors;
  std::atomic<int> rank;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
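
// A sketch of how a group of two tokens is expected to be used (hypothetical
// host-side code; `tokenA` and `tokenB` stand for tokens produced elsewhere):
//
//   AsyncGroup *group = mlirAsyncRuntimeCreateGroup(/*size=*/2);
//   mlirAsyncRuntimeAddTokenToGroup(tokenA, group);
//   mlirAsyncRuntimeAddTokenToGroup(tokenB, group);
//   mlirAsyncRuntimeAwaitAllInGroup(group);            // wait for both tokens
//   bool failed = mlirAsyncRuntimeIsGroupError(group); // did any token error?
//   mlirAsyncRuntimeDropRef(group, 1);                 // group refCount == 1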

// Adds references to a reference counted runtime object.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->addRef(count);
}

// Drops references from a reference counted runtime object.
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->dropRef(count);
}

// Creates a new `async.token` in a not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
  return token;
}

// Creates a new `async.value` in a not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
  return value;
}

// Creates a new `async.group` in an empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
  return group;
}

extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
                                                   AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  // Get the rank of the token inside the group before we drop the reference.
  int rank = group->rank.fetch_add(1);

  auto onTokenReady = [group, token]() {
    // Increment the number of errors in the group.
    if (State(token->state).isError())
      group->numErrors.fetch_add(1);

    // If the pending tokens counter would go below zero, more tokens than the
    // group size were added to this group.
    assert(group->pendingTokens > 0 && "wrong group size");

    // Run all group awaiters if this was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (State(token->state).isAvailableOrError()) {
    // Update group pending tokens immediately and maybe run awaiters.
    onTokenReady();

  } else {
    // Update group pending tokens when the token becomes ready. Because this
    // happens asynchronously we must ensure that `group` is alive until then,
    // and re-acquire the lock.
    group->addRef();

    token->awaiters.emplace_back([group, onTokenReady]() {
      // Make sure that `dropRef` does not destroy the mutex owned by the lock.
      {
        std::unique_lock<std::mutex> lockGroup(group->mu);
        onTokenReady();
      }
      group->dropRef();
    });
  }

  return rank;
}
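
// Pending token accounting for a group of size 2 (a worked example of the
// code above, not generated code): `pendingTokens` starts at 2, and each token
// that reaches a terminal state decrements it inside `onTokenReady`. The
// decrement that observes the old value 1 (i.e. the last token) notifies the
// condition variable and runs all queued group awaiters exactly once.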

// Switches `async.token` to the available or error state (a terminal state)
// and runs all awaiters.
static void setTokenState(AsyncToken *token, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(token->state).isUnavailable() && "token must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(token->mu);
    token->state = state;
    token->cv.notify_all();
    for (auto &awaiter : token->awaiters)
      awaiter();
  }

  // Async tokens are created with a reference count of 2 to keep the token
  // alive until the async task completes. Drop this reference explicitly when
  // the token is emplaced.
  token->dropRef();
}

static void setValueState(AsyncValue *value, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(value->state).isUnavailable() && "value must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(value->mu);
    value->state = state;
    value->cv.notify_all();
    for (auto &awaiter : value->awaiters)
      awaiter();
  }

  // Async values are created with a reference count of 2 to keep the value
  // alive until the async task completes. Drop this reference explicitly when
  // the value is emplaced.
  value->dropRef();
}

extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  setTokenState(token, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
  setValueState(value, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
  setTokenState(token, State::kError);
}

extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
  setValueState(value, State::kError);
}

extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
  return State(token->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
  return State(value->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
  return group->numErrors.load() > 0;
}

extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> lock(token->mu);
  if (!State(token->state).isAvailableOrError())
    token->cv.wait(
        lock, [token] { return State(token->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
  std::unique_lock<std::mutex> lock(value->mu);
  if (!State(value->state).isAvailableOrError())
    value->cv.wait(
        lock, [value] { return State(value->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens != 0)
    group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}

// Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
  assert(!State(value->state).isError() && "unexpected error state");
  return value->storage.data();
}
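
// A sketch of a producer/consumer pair for a 4-byte `async.value` payload
// (hypothetical host-side code; in practice the loads/stores into the storage
// are emitted by the lowered program):
//
//   AsyncValue *value = mlirAsyncRuntimeCreateValue(/*size=*/4);
//   // Producer: write the payload, then signal availability.
//   *reinterpret_cast<int32_t *>(mlirAsyncRuntimeGetValueStorage(value)) = 42;
//   mlirAsyncRuntimeEmplaceValue(value); // drops the producer reference
//   // Consumer: wait for availability, read the payload back, drop the ref.
//   mlirAsyncRuntimeAwaitValue(value);
//   int32_t result = *reinterpret_cast<int32_t *>(
//       mlirAsyncRuntimeGetValueStorage(value));
//   mlirAsyncRuntimeDropRef(value, 1);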

extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto *runtime = getDefaultAsyncRuntime();
  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}

extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(token->mu);
  if (State(token->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    token->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(value->mu);
  if (State(value->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    value->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                          CoroHandle handle,
                                                          CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens == 0) {
    lock.unlock();
    execute();
  } else {
    group->awaiters.emplace_back([execute]() { execute(); });
  }
}
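
// A sketch of the non-blocking await path provided by the *AndExecute
// functions above (hypothetical; `handle` and `resume` are presumed to come
// from the coroutine machinery set up by the Async-to-LLVM lowering):
//
//   // Instead of blocking in mlirAsyncRuntimeAwaitToken, a coroutine suspends
//   // itself and registers its resumption as an awaiter:
//   mlirAsyncRuntimeAwaitTokenAndExecute(token, handle, resume);
//   // `(*resume)(handle)` runs either immediately, if the token is already in
//   // a terminal state, or later on the thread that emplaces the token.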

extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
  return getDefaultAsyncRuntime()->getThreadPool().getThreadCount();
}

//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id thisId = std::this_thread::get_id();
  std::cout << "Current thread id: " << thisId << std::endl;
}

//===----------------------------------------------------------------------===//
// MLIR Runner (JitRunner) dynamic library integration.
//===----------------------------------------------------------------------===//

// Export symbols for the MLIR runner integration. All other symbols are hidden.
#ifdef _WIN32
#define API __declspec(dllexport)
#else
#define API __attribute__((visibility("default")))
#endif

// Older versions of Visual Studio fail to compile nested generic lambdas
// inside an `extern "C"` function:
//   https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a workaround for older versions of Visual Studio.
// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);

// NOLINTNEXTLINE(*-identifier-naming): externally called.
void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
    assert(exportSymbols.count(name) == 0 && "symbol already exists");
    exportSymbols[name] = reinterpret_cast<void *>(ptr);
  };

  exportSymbol("mlirAsyncRuntimeAddRef",
               &mlir::runtime::mlirAsyncRuntimeAddRef);
  exportSymbol("mlirAsyncRuntimeDropRef",
               &mlir::runtime::mlirAsyncRuntimeDropRef);
  exportSymbol("mlirAsyncRuntimeExecute",
               &mlir::runtime::mlirAsyncRuntimeExecute);
  exportSymbol("mlirAsyncRuntimeGetValueStorage",
               &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
  exportSymbol("mlirAsyncRuntimeCreateToken",
               &mlir::runtime::mlirAsyncRuntimeCreateToken);
  exportSymbol("mlirAsyncRuntimeCreateValue",
               &mlir::runtime::mlirAsyncRuntimeCreateValue);
  exportSymbol("mlirAsyncRuntimeEmplaceToken",
               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
  exportSymbol("mlirAsyncRuntimeEmplaceValue",
               &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
  exportSymbol("mlirAsyncRuntimeSetTokenError",
               &mlir::runtime::mlirAsyncRuntimeSetTokenError);
  exportSymbol("mlirAsyncRuntimeSetValueError",
               &mlir::runtime::mlirAsyncRuntimeSetValueError);
  exportSymbol("mlirAsyncRuntimeIsTokenError",
               &mlir::runtime::mlirAsyncRuntimeIsTokenError);
  exportSymbol("mlirAsyncRuntimeIsValueError",
               &mlir::runtime::mlirAsyncRuntimeIsValueError);
  exportSymbol("mlirAsyncRuntimeIsGroupError",
               &mlir::runtime::mlirAsyncRuntimeIsGroupError);
  exportSymbol("mlirAsyncRuntimeAwaitToken",
               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
  exportSymbol("mlirAsyncRuntimeAwaitValue",
               &mlir::runtime::mlirAsyncRuntimeAwaitValue);
  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
  exportSymbol("mlirAsyncRuntimeCreateGroup",
               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
  exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
               &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}

// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }

} // namespace runtime
} // namespace mlir

#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS