1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15 
16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
17 
18 #include <atomic>
19 #include <condition_variable>
20 #include <functional>
21 #include <iostream>
22 #include <mutex>
23 #include <thread>
24 #include <vector>
25 
26 //===----------------------------------------------------------------------===//
27 // Async runtime API.
28 //===----------------------------------------------------------------------===//
29 
// Runtime state backing an `async.token` value.
//
// A token starts out not ready. `mlirAsyncRuntimeEmplaceToken` sets `ready`
// while holding `mu`, wakes threads blocked on `cv`, and invokes every
// callback stored in `awaiters`.
struct AsyncToken {
  // True once the token has been emplaced; never reset afterwards.
  bool ready{false};
  // Serializes access to `ready` and `awaiters`.
  std::mutex mu;
  // Signalled when `ready` becomes true; used by blocking awaits.
  std::condition_variable cv;
  // Continuations registered before the token became ready.
  std::vector<std::function<void()>> awaiters;
};
36 
// Runtime state backing an `async.group` value.
//
// The group functions below hold `mu` while they inspect `pendingTokens` or
// register `awaiters`. When the last pending token becomes ready, threads
// blocked on `cv` are woken and every callback in `awaiters` is run.
struct AsyncGroup {
  // Number of tokens added to the group that have not become ready yet.
  std::atomic<int> pendingTokens{0};
  // Counter that hands out each added token's index within the group.
  std::atomic<int> rank{0};
  // Mutex paired with `cv`.
  std::mutex mu;
  // Signalled when `pendingTokens` drops to zero.
  std::condition_variable cv;
  // Continuations to run once all tokens in the group are ready.
  std::vector<std::function<void()>> awaiters;
};
44 
45 // Create a new `async.token` in not-ready state.
46 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
47   AsyncToken *token = new AsyncToken;
48   return token;
49 }
50 
51 // Create a new `async.group` in empty state.
52 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
53   AsyncGroup *group = new AsyncGroup;
54   return group;
55 }
56 
57 extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
58 mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
59   std::unique_lock<std::mutex> lockToken(token->mu);
60   std::unique_lock<std::mutex> lockGroup(group->mu);
61 
62   group->pendingTokens.fetch_add(1);
63 
64   auto onTokenReady = [group]() {
65     // Run all group awaiters if it was the last token in the group.
66     if (group->pendingTokens.fetch_sub(1) == 1) {
67       group->cv.notify_all();
68       for (auto &awaiter : group->awaiters)
69         awaiter();
70     }
71   };
72 
73   if (token->ready)
74     onTokenReady();
75   else
76     token->awaiters.push_back([onTokenReady]() { onTokenReady(); });
77 
78   return group->rank.fetch_add(1);
79 }
80 
81 // Switches `async.token` to ready state and runs all awaiters.
82 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
83   std::unique_lock<std::mutex> lock(token->mu);
84   token->ready = true;
85   token->cv.notify_all();
86   for (auto &awaiter : token->awaiters)
87     awaiter();
88 }
89 
90 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
91   std::unique_lock<std::mutex> lock(token->mu);
92   if (!token->ready)
93     token->cv.wait(lock, [token] { return token->ready; });
94 }
95 
96 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
97 mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
98   std::unique_lock<std::mutex> lock(group->mu);
99   if (group->pendingTokens != 0)
100     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
101 }
102 
103 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
104 #if LLVM_ENABLE_THREADS
105   std::thread thread([handle, resume]() { (*resume)(handle); });
106   thread.detach();
107 #else
108   (*resume)(handle);
109 #endif
110 }
111 
112 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
113                                                      CoroHandle handle,
114                                                      CoroResume resume) {
115   std::unique_lock<std::mutex> lock(token->mu);
116 
117   auto execute = [handle, resume]() {
118     mlirAsyncRuntimeExecute(handle, resume);
119   };
120 
121   if (token->ready)
122     execute();
123   else
124     token->awaiters.push_back([execute]() { execute(); });
125 }
126 
127 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
128 mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
129                                           CoroResume resume) {
130   std::unique_lock<std::mutex> lock(group->mu);
131 
132   auto execute = [handle, resume]() {
133     mlirAsyncRuntimeExecute(handle, resume);
134   };
135 
136   if (group->pendingTokens == 0)
137     execute();
138   else
139     group->awaiters.push_back([execute]() { execute(); });
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Small async runtime support library for testing.
144 //===----------------------------------------------------------------------===//
145 
// Testing helper: prints the id of the calling thread to standard output.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // Looked up once per thread, then reused on later calls from that thread.
  static thread_local const std::thread::id currentThreadId =
      std::this_thread::get_id();
  std::cout << "Current thread id: " << currentThreadId << "\n";
}
150 
151 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
152