//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the basic Async runtime API that supports lowering
// from the Async dialect to the LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward declare the class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations, and the Async runtime API is
// built on top of the default runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  ~AsyncRuntime() {
    threadPool.wait(); // wait for the completion of all async tasks
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  int64_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

  llvm::ThreadPool &getThreadPool() { return threadPool; }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int64_t> numRefCountedObjects;
  llvm::ThreadPool threadPool;
};

// -------------------------------------------------------------------------- //
// The state of an async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state cannot
    // transition to any other state.
    kAvailable = 1,
    // The underlying value is available and contains an error. This state
    // cannot transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
  }

private:
  StateEnum state;
};

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  void addRef(int64_t count = 1) { refCount.fetch_add(count); }

  void dropRef(int64_t count = 1) {
    int64_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int64_t> refCount;
};

} // namespace

// Returns the default per-process instance of an async runtime.
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime;
}

static void resetDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().reset();
}

static AsyncRuntime *getDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().get();
}

// Async token provides a mechanism to signal asynchronous operation completion.
struct AsyncToken : public RefCounted {
  // An AsyncToken is created with a reference count of 2 because it will be
  // returned to the `async.execute` caller and will also later be emplaced by
  // the asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the token stays alive until the
  // asynchronous operation is completed.
  AsyncToken(AsyncRuntime *runtime)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}

  std::atomic<State::StateEnum> state;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
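
// A rough sketch of the token lifecycle as driven by lowered Async dialect
// code (illustrative only; the exact calls emitted by the lowering may
// differ):
//
//   // Producer side (the asynchronously executed task):
//   AsyncToken *token = mlirAsyncRuntimeCreateToken(); // refCount == 2
//   ...                                                // run the async work
//   mlirAsyncRuntimeEmplaceToken(token);               // signal completion and
//                                                      // drop the task's ref
//
//   // Consumer side (the `async.execute` caller):
//   mlirAsyncRuntimeAwaitToken(token);                 // block until available
//   mlirAsyncRuntimeDropRef(token, /*count=*/1);       // release the caller's
//                                                      // reference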

// Async value provides a mechanism to access the result of asynchronous
// operations. It owns the storage that is used to store/load the value of the
// underlying type, and a flag to signal if the value is ready or not.
struct AsyncValue : public RefCounted {
  // Like an AsyncToken, an AsyncValue is created with a reference count of 2.
  AsyncValue(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
        storage(size) {}

  std::atomic<State::StateEnum> state;

  // Use a vector of bytes to store the async value payload.
  std::vector<int8_t> storage;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
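
// A rough sketch of how a producer and a consumer interact with an async
// value through this API (illustrative only; the lowering emits similar but
// not necessarily identical code, and the 4-byte int32 payload is an
// arbitrary example):
//
//   // Producer: allocate storage for a 4-byte payload, write it, emplace.
//   AsyncValue *value = mlirAsyncRuntimeCreateValue(/*size=*/4);
//   *reinterpret_cast<int32_t *>(mlirAsyncRuntimeGetValueStorage(value)) = 42;
//   mlirAsyncRuntimeEmplaceValue(value); // make the payload available and
//                                        // drop the producer's reference
//
//   // Consumer: wait for the payload and read it back.
//   mlirAsyncRuntimeAwaitValue(value);
//   int32_t result = *reinterpret_cast<int32_t *>(
//       mlirAsyncRuntimeGetValueStorage(value));
//   mlirAsyncRuntimeDropRef(value, /*count=*/1);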

// Async group provides a mechanism to group together multiple async tokens or
// values and await the completion of all tokens or values added to the group.
struct AsyncGroup : public RefCounted {
  AsyncGroup(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}

  std::atomic<int> pendingTokens;
  std::atomic<int> numErrors;
  std::atomic<int> rank;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
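
// A rough sketch of typical group usage (illustrative only; `tokenA` and
// `tokenB` stand for previously created tokens): create a group sized for the
// number of tokens, add each token, then await all of them and check for
// errors.
//
//   AsyncGroup *group = mlirAsyncRuntimeCreateGroup(/*size=*/2);
//   mlirAsyncRuntimeAddTokenToGroup(tokenA, group);
//   mlirAsyncRuntimeAddTokenToGroup(tokenB, group);
//   mlirAsyncRuntimeAwaitAllInGroup(group);            // blocks until both
//                                                      // tokens are ready
//   bool failed = mlirAsyncRuntimeIsGroupError(group); // true if any token
//                                                      // was set to error
//   mlirAsyncRuntimeDropRef(group, /*count=*/1);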

// Adds references to a reference counted runtime object.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->addRef(count);
}

// Drops references from a reference counted runtime object.
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->dropRef(count);
}
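
// Reference counting contract (sketch, based on the semantics above): every
// added reference must eventually be matched by a drop with the same total
// count, and dropping the last reference destroys the object. For example:
//
//   mlirAsyncRuntimeAddRef(token, /*count=*/1);  // share with a second owner
//   ...                                          // both owners may use token
//   mlirAsyncRuntimeDropRef(token, /*count=*/1); // owner 1 releases
//   mlirAsyncRuntimeDropRef(token, /*count=*/1); // owner 2 releases; if this
//                                                // was the last reference the
//                                                // token is destroyed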

// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
  return token;
}

// Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
  return value;
}

// Creates a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
  return group;
}

extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
                                                   AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  // Get the rank of the token inside the group before we drop the reference.
  int rank = group->rank.fetch_add(1);

  auto onTokenReady = [group, token]() {
    // Increment the number of errors in the group.
    if (State(token->state).isError())
      group->numErrors.fetch_add(1);

    // If pending tokens go below zero it means that more tokens than the group
    // size were added to this group.
    assert(group->pendingTokens > 0 && "wrong group size");

    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (State(token->state).isAvailableOrError()) {
    // Update group pending tokens immediately and maybe run awaiters.
    onTokenReady();

  } else {
    // Update group pending tokens when the token becomes ready. Because this
    // will happen asynchronously, we must ensure that `group` is alive until
    // then, and re-acquire the lock.
    group->addRef();

    token->awaiters.emplace_back([group, onTokenReady]() {
      // Make sure that `dropRef` does not destroy the mutex owned by the lock.
      {
        std::unique_lock<std::mutex> lockGroup(group->mu);
        onTokenReady();
      }
      group->dropRef();
    });
  }

  return rank;
}

// Switches `async.token` to available or error state (terminal state) and runs
// all awaiters.
static void setTokenState(AsyncToken *token, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(token->state).isUnavailable() && "token must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(token->mu);
    token->state = state;
    token->cv.notify_all();
    for (auto &awaiter : token->awaiters)
      awaiter();
  }

  // Async tokens are created with a ref count of 2 to keep the token alive
  // until the async task completes. Drop this reference explicitly when the
  // token is emplaced.
  token->dropRef();
}

static void setValueState(AsyncValue *value, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(value->state).isUnavailable() && "value must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(value->mu);
    value->state = state;
    value->cv.notify_all();
    for (auto &awaiter : value->awaiters)
      awaiter();
  }

  // Async values are created with a ref count of 2 to keep the value alive
  // until the async task completes. Drop this reference explicitly when the
  // value is emplaced.
  value->dropRef();
}

extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  setTokenState(token, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
  setValueState(value, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
  setTokenState(token, State::kError);
}

extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
  setValueState(value, State::kError);
}

extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
  return State(token->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
  return State(value->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
  return group->numErrors.load() > 0;
}

extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> lock(token->mu);
  if (!State(token->state).isAvailableOrError())
    token->cv.wait(
        lock, [token] { return State(token->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
  std::unique_lock<std::mutex> lock(value->mu);
  if (!State(value->state).isAvailableOrError())
    value->cv.wait(
        lock, [value] { return State(value->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens != 0)
    group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}

// Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
  assert(!State(value->state).isError() && "unexpected error state");
  return value->storage.data();
}

extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto *runtime = getDefaultAsyncRuntime();
  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}

extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(token->mu);
  if (State(token->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    token->awaiters.emplace_back([execute]() { execute(); });
  }
}
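
// A rough sketch of how a lowered coroutine is expected to use the
// non-blocking await (illustrative only): instead of blocking with
// mlirAsyncRuntimeAwaitToken, the coroutine registers itself as an awaiter
// and suspends; `resume` is invoked with `handle` either immediately (if the
// token is already in a terminal state) or later, on the thread that emplaces
// the token or sets its error.
//
//   // Inside the coroutine body, at an `async.await` point:
//   mlirAsyncRuntimeAwaitTokenAndExecute(token, handle, resume);
//   // ... suspend; `resume(handle)` continues execution after the await.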

extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(value->mu);
  if (State(value->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    value->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                          CoroHandle handle,
                                                          CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens == 0) {
    lock.unlock();
    execute();
  } else {
    group->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
  return getDefaultAsyncRuntime()->getThreadPool().getThreadCount();
}

//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id thisId = std::this_thread::get_id();
  std::cout << "Current thread id: " << thisId << std::endl;
}

//===----------------------------------------------------------------------===//
// MLIR Runner (JitRunner) dynamic library integration.
//===----------------------------------------------------------------------===//

// Export symbols for the MLIR runner integration. All other symbols are hidden.
#ifdef _WIN32
#define API __declspec(dllexport)
#else
#define API __attribute__((visibility("default")))
#endif

// Visual Studio had a bug that made it fail to compile nested generic lambdas
// inside an `extern "C"` function.
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a workaround for older versions of Visual Studio.
// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);

// NOLINTNEXTLINE(*-identifier-naming): externally called.
void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
    assert(exportSymbols.count(name) == 0 && "symbol already exists");
    exportSymbols[name] = reinterpret_cast<void *>(ptr);
  };

  exportSymbol("mlirAsyncRuntimeAddRef",
               &mlir::runtime::mlirAsyncRuntimeAddRef);
  exportSymbol("mlirAsyncRuntimeDropRef",
               &mlir::runtime::mlirAsyncRuntimeDropRef);
  exportSymbol("mlirAsyncRuntimeExecute",
               &mlir::runtime::mlirAsyncRuntimeExecute);
  exportSymbol("mlirAsyncRuntimeGetValueStorage",
               &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
  exportSymbol("mlirAsyncRuntimeCreateToken",
               &mlir::runtime::mlirAsyncRuntimeCreateToken);
  exportSymbol("mlirAsyncRuntimeCreateValue",
               &mlir::runtime::mlirAsyncRuntimeCreateValue);
  exportSymbol("mlirAsyncRuntimeEmplaceToken",
               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
  exportSymbol("mlirAsyncRuntimeEmplaceValue",
               &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
  exportSymbol("mlirAsyncRuntimeSetTokenError",
               &mlir::runtime::mlirAsyncRuntimeSetTokenError);
  exportSymbol("mlirAsyncRuntimeSetValueError",
               &mlir::runtime::mlirAsyncRuntimeSetValueError);
  exportSymbol("mlirAsyncRuntimeIsTokenError",
               &mlir::runtime::mlirAsyncRuntimeIsTokenError);
  exportSymbol("mlirAsyncRuntimeIsValueError",
               &mlir::runtime::mlirAsyncRuntimeIsValueError);
  exportSymbol("mlirAsyncRuntimeIsGroupError",
               &mlir::runtime::mlirAsyncRuntimeIsGroupError);
  exportSymbol("mlirAsyncRuntimeAwaitToken",
               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
  exportSymbol("mlirAsyncRuntimeAwaitValue",
               &mlir::runtime::mlirAsyncRuntimeAwaitValue);
  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
  exportSymbol("mlirAsyncRuntimeCreateGroup",
               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
  exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
               &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}

// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }

} // namespace runtime
} // namespace mlir

#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS