1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include <__threading_support>
#include <errno.h>
#define NOMINMAX
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <process.h>
#include <fibersapi.h>
15 
16 _LIBCPP_BEGIN_NAMESPACE_STD
17 
18 static_assert(sizeof(__libcpp_mutex_t) == sizeof(SRWLOCK), "");
19 static_assert(alignof(__libcpp_mutex_t) == alignof(SRWLOCK), "");
20 
21 static_assert(sizeof(__libcpp_recursive_mutex_t) == sizeof(CRITICAL_SECTION),
22               "");
23 static_assert(alignof(__libcpp_recursive_mutex_t) == alignof(CRITICAL_SECTION),
24               "");
25 
26 static_assert(sizeof(__libcpp_condvar_t) == sizeof(CONDITION_VARIABLE), "");
27 static_assert(alignof(__libcpp_condvar_t) == alignof(CONDITION_VARIABLE), "");
28 
29 static_assert(sizeof(__libcpp_exec_once_flag) == sizeof(INIT_ONCE), "");
30 static_assert(alignof(__libcpp_exec_once_flag) == alignof(INIT_ONCE), "");
31 
32 static_assert(sizeof(__libcpp_thread_id) == sizeof(DWORD), "");
33 static_assert(alignof(__libcpp_thread_id) == alignof(DWORD), "");
34 
35 static_assert(sizeof(__libcpp_thread_t) == sizeof(HANDLE), "");
36 static_assert(alignof(__libcpp_thread_t) == alignof(HANDLE), "");
37 
38 static_assert(sizeof(__libcpp_tls_key) == sizeof(DWORD), "");
39 static_assert(alignof(__libcpp_tls_key) == alignof(DWORD), "");
40 
41 static_assert(sizeof(__libcpp_semaphore_t) == sizeof(HANDLE), "");
42 static_assert(alignof(__libcpp_semaphore_t) == alignof(HANDLE), "");
43 
44 // Mutex
45 int __libcpp_recursive_mutex_init(__libcpp_recursive_mutex_t *__m)
46 {
47   InitializeCriticalSection((LPCRITICAL_SECTION)__m);
48   return 0;
49 }
50 
51 int __libcpp_recursive_mutex_lock(__libcpp_recursive_mutex_t *__m)
52 {
53   EnterCriticalSection((LPCRITICAL_SECTION)__m);
54   return 0;
55 }
56 
57 bool __libcpp_recursive_mutex_trylock(__libcpp_recursive_mutex_t *__m)
58 {
59   return TryEnterCriticalSection((LPCRITICAL_SECTION)__m) != 0;
60 }
61 
62 int __libcpp_recursive_mutex_unlock(__libcpp_recursive_mutex_t *__m)
63 {
64   LeaveCriticalSection((LPCRITICAL_SECTION)__m);
65   return 0;
66 }
67 
68 int __libcpp_recursive_mutex_destroy(__libcpp_recursive_mutex_t *__m)
69 {
70   DeleteCriticalSection((LPCRITICAL_SECTION)__m);
71   return 0;
72 }
73 
74 int __libcpp_mutex_lock(__libcpp_mutex_t *__m)
75 {
76   AcquireSRWLockExclusive((PSRWLOCK)__m);
77   return 0;
78 }
79 
80 bool __libcpp_mutex_trylock(__libcpp_mutex_t *__m)
81 {
82   return TryAcquireSRWLockExclusive((PSRWLOCK)__m) != 0;
83 }
84 
85 int __libcpp_mutex_unlock(__libcpp_mutex_t *__m)
86 {
87   ReleaseSRWLockExclusive((PSRWLOCK)__m);
88   return 0;
89 }
90 
91 int __libcpp_mutex_destroy(__libcpp_mutex_t *__m)
92 {
93   static_cast<void>(__m);
94   return 0;
95 }
96 
97 // Condition Variable
98 int __libcpp_condvar_signal(__libcpp_condvar_t *__cv)
99 {
100   WakeConditionVariable((PCONDITION_VARIABLE)__cv);
101   return 0;
102 }
103 
104 int __libcpp_condvar_broadcast(__libcpp_condvar_t *__cv)
105 {
106   WakeAllConditionVariable((PCONDITION_VARIABLE)__cv);
107   return 0;
108 }
109 
110 int __libcpp_condvar_wait(__libcpp_condvar_t *__cv, __libcpp_mutex_t *__m)
111 {
112   SleepConditionVariableSRW((PCONDITION_VARIABLE)__cv, (PSRWLOCK)__m, INFINITE, 0);
113   return 0;
114 }
115 
116 int __libcpp_condvar_timedwait(__libcpp_condvar_t *__cv, __libcpp_mutex_t *__m,
117                                __libcpp_timespec_t *__ts)
118 {
119   using namespace _VSTD::chrono;
120 
121   auto duration = seconds(__ts->tv_sec) + nanoseconds(__ts->tv_nsec);
122   auto abstime =
123       system_clock::time_point(duration_cast<system_clock::duration>(duration));
124   auto timeout_ms = duration_cast<milliseconds>(abstime - system_clock::now());
125 
126   if (!SleepConditionVariableSRW((PCONDITION_VARIABLE)__cv, (PSRWLOCK)__m,
127                                  timeout_ms.count() > 0 ? timeout_ms.count()
128                                                         : 0,
129                                  0))
130     {
131       auto __ec = GetLastError();
132       return __ec == ERROR_TIMEOUT ? ETIMEDOUT : __ec;
133     }
134   return 0;
135 }
136 
137 int __libcpp_condvar_destroy(__libcpp_condvar_t *__cv)
138 {
139   static_cast<void>(__cv);
140   return 0;
141 }
142 
143 // Execute Once
144 static inline _LIBCPP_INLINE_VISIBILITY BOOL CALLBACK
145 __libcpp_init_once_execute_once_thunk(PINIT_ONCE __init_once, PVOID __parameter,
146                                       PVOID *__context)
147 {
148   static_cast<void>(__init_once);
149   static_cast<void>(__context);
150 
151   void (*init_routine)(void) = reinterpret_cast<void (*)(void)>(__parameter);
152   init_routine();
153   return TRUE;
154 }
155 
156 int __libcpp_execute_once(__libcpp_exec_once_flag *__flag,
157                           void (*__init_routine)(void))
158 {
159   if (!InitOnceExecuteOnce((PINIT_ONCE)__flag, __libcpp_init_once_execute_once_thunk,
160                            reinterpret_cast<void *>(__init_routine), NULL))
161     return GetLastError();
162   return 0;
163 }
164 
165 // Thread ID
166 bool __libcpp_thread_id_equal(__libcpp_thread_id __lhs,
167                               __libcpp_thread_id __rhs)
168 {
169   return __lhs == __rhs;
170 }
171 
172 bool __libcpp_thread_id_less(__libcpp_thread_id __lhs, __libcpp_thread_id __rhs)
173 {
174   return __lhs < __rhs;
175 }
176 
177 // Thread
// Heap-allocated bundle passed through _beginthreadex's single `void *`
// argument: allocated and filled in by __libcpp_thread_create, then consumed
// and deleted by __libcpp_beginthreadex_thunk.
struct __libcpp_beginthreadex_thunk_data
{
  // Thread entry point in the `void *(*)(void *)` form used by the threading
  // API in this file.
  void *(*__func)(void *);

  // Opaque argument forwarded to __func.
  void *__arg;
};
183 
184 static inline _LIBCPP_INLINE_VISIBILITY unsigned WINAPI
185 __libcpp_beginthreadex_thunk(void *__raw_data)
186 {
187   auto *__data =
188       static_cast<__libcpp_beginthreadex_thunk_data *>(__raw_data);
189   auto *__func = __data->__func;
190   void *__arg = __data->__arg;
191   delete __data;
192   return static_cast<unsigned>(reinterpret_cast<uintptr_t>(__func(__arg)));
193 }
194 
195 bool __libcpp_thread_isnull(const __libcpp_thread_t *__t) {
196   return *__t == 0;
197 }
198 
199 int __libcpp_thread_create(__libcpp_thread_t *__t, void *(*__func)(void *),
200                            void *__arg)
201 {
202   auto *__data = new __libcpp_beginthreadex_thunk_data;
203   __data->__func = __func;
204   __data->__arg = __arg;
205 
206   *__t = reinterpret_cast<HANDLE>(_beginthreadex(nullptr, 0,
207                                                  __libcpp_beginthreadex_thunk,
208                                                  __data, 0, nullptr));
209 
210   if (*__t)
211     return 0;
212   return GetLastError();
213 }
214 
215 __libcpp_thread_id __libcpp_thread_get_current_id()
216 {
217   return GetCurrentThreadId();
218 }
219 
220 __libcpp_thread_id __libcpp_thread_get_id(const __libcpp_thread_t *__t)
221 {
222   return GetThreadId(*__t);
223 }
224 
225 int __libcpp_thread_join(__libcpp_thread_t *__t)
226 {
227   if (WaitForSingleObjectEx(*__t, INFINITE, FALSE) == WAIT_FAILED)
228     return GetLastError();
229   if (!CloseHandle(*__t))
230     return GetLastError();
231   return 0;
232 }
233 
234 int __libcpp_thread_detach(__libcpp_thread_t *__t)
235 {
236   if (!CloseHandle(*__t))
237     return GetLastError();
238   return 0;
239 }
240 
241 void __libcpp_thread_yield()
242 {
243   SwitchToThread();
244 }
245 
246 void __libcpp_thread_sleep_for(const chrono::nanoseconds& __ns)
247 {
248   // round-up to the nearest millisecond
249   chrono::milliseconds __ms = chrono::ceil<chrono::milliseconds>(__ns);
250   // FIXME(compnerd) this should be an alertable sleep (WFSO or SleepEx)
251   Sleep(__ms.count());
252 }
253 
254 // Thread Local Storage
255 int __libcpp_tls_create(__libcpp_tls_key* __key,
256                         void(_LIBCPP_TLS_DESTRUCTOR_CC* __at_exit)(void*))
257 {
258   DWORD index = FlsAlloc(__at_exit);
259   if (index == FLS_OUT_OF_INDEXES)
260     return GetLastError();
261   *__key = index;
262   return 0;
263 }
264 
265 void *__libcpp_tls_get(__libcpp_tls_key __key)
266 {
267   return FlsGetValue(__key);
268 }
269 
270 int __libcpp_tls_set(__libcpp_tls_key __key, void *__p)
271 {
272   if (!FlsSetValue(__key, __p))
273     return GetLastError();
274   return 0;
275 }
276 
277 // Semaphores
278 bool __libcpp_semaphore_init(__libcpp_semaphore_t* __sem, int __init)
279 {
280   *(PHANDLE)__sem = CreateSemaphoreEx(nullptr, __init, _LIBCPP_SEMAPHORE_MAX,
281                                       nullptr, 0, SEMAPHORE_ALL_ACCESS);
282   return *__sem != nullptr;
283 }
284 
285 bool __libcpp_semaphore_destroy(__libcpp_semaphore_t* __sem)
286 {
287   CloseHandle(*(PHANDLE)__sem);
288   return true;
289 }
290 
291 bool __libcpp_semaphore_post(__libcpp_semaphore_t* __sem)
292 {
293   return ReleaseSemaphore(*(PHANDLE)__sem, 1, nullptr);
294 }
295 
296 bool __libcpp_semaphore_wait(__libcpp_semaphore_t* __sem)
297 {
298   return WaitForSingleObjectEx(*(PHANDLE)__sem, INFINITE, false) ==
299          WAIT_OBJECT_0;
300 }
301 
302 bool __libcpp_semaphore_wait_timed(__libcpp_semaphore_t* __sem,
303                                    chrono::nanoseconds const& __ns)
304 {
305   chrono::milliseconds __ms = chrono::ceil<chrono::milliseconds>(__ns);
306   return WaitForSingleObjectEx(*(PHANDLE)__sem, __ms.count(), false) ==
307          WAIT_OBJECT_0;
308 }
309 
310 _LIBCPP_END_NAMESPACE_STD
311