1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements basic Async runtime API for supporting Async dialect 10 // to LLVM dialect lowering. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/AsyncRuntime.h" 15 16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 17 18 #include <atomic> 19 #include <cassert> 20 #include <condition_variable> 21 #include <functional> 22 #include <iostream> 23 #include <mutex> 24 #include <thread> 25 #include <vector> 26 27 #include "llvm/ADT/StringMap.h" 28 #include "llvm/Support/ThreadPool.h" 29 30 using namespace mlir::runtime; 31 32 //===----------------------------------------------------------------------===// 33 // Async runtime API. 34 //===----------------------------------------------------------------------===// 35 36 namespace mlir { 37 namespace runtime { 38 namespace { 39 40 // Forward declare class defined below. 41 class RefCounted; 42 43 // -------------------------------------------------------------------------- // 44 // AsyncRuntime orchestrates all async operations and Async runtime API is built 45 // on top of the default runtime instance. 46 // -------------------------------------------------------------------------- // 47 48 class AsyncRuntime { 49 public: 50 AsyncRuntime() : numRefCountedObjects(0) {} 51 52 ~AsyncRuntime() { 53 threadPool.wait(); // wait for the completion of all async tasks 54 assert(getNumRefCountedObjects() == 0 && 55 "all ref counted objects must be destroyed"); 56 } 57 58 int64_t getNumRefCountedObjects() { 59 return numRefCountedObjects.load(std::memory_order_relaxed); 60 } 61 62 llvm::ThreadPool &getThreadPool() { return threadPool; } 63 64 private: 65 friend class RefCounted; 66 67 // Count the total number of reference counted objects in this instance 68 // of an AsyncRuntime. For debugging purposes only. 69 void addNumRefCountedObjects() { 70 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); 71 } 72 void dropNumRefCountedObjects() { 73 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); 74 } 75 76 std::atomic<int64_t> numRefCountedObjects; 77 llvm::ThreadPool threadPool; 78 }; 79 80 // -------------------------------------------------------------------------- // 81 // A state of the async runtime value (token, value or group). 82 // -------------------------------------------------------------------------- // 83 84 class State { 85 public: 86 enum StateEnum : int8_t { 87 // The underlying value is not yet available for consumption. 88 kUnavailable = 0, 89 // The underlying value is available for consumption. This state can not 90 // transition to any other state. 91 kAvailable = 1, 92 // This underlying value is available and contains an error. This state can 93 // not transition to any other state. 94 kError = 2, 95 }; 96 97 /* implicit */ State(StateEnum s) : state(s) {} 98 /* implicit */ operator StateEnum() { return state; } 99 100 bool isUnavailable() const { return state == kUnavailable; } 101 bool isAvailable() const { return state == kAvailable; } 102 bool isError() const { return state == kError; } 103 bool isAvailableOrError() const { return isAvailable() || isError(); } 104 105 const char *debug() const { 106 switch (state) { 107 case kUnavailable: 108 return "unavailable"; 109 case kAvailable: 110 return "available"; 111 case kError: 112 return "error"; 113 } 114 } 115 116 private: 117 StateEnum state; 118 }; 119 120 // -------------------------------------------------------------------------- // 121 // A base class for all reference counted objects created by the async runtime. 122 // -------------------------------------------------------------------------- // 123 124 class RefCounted { 125 public: 126 RefCounted(AsyncRuntime *runtime, int64_t refCount = 1) 127 : runtime(runtime), refCount(refCount) { 128 runtime->addNumRefCountedObjects(); 129 } 130 131 virtual ~RefCounted() { 132 assert(refCount.load() == 0 && "reference count must be zero"); 133 runtime->dropNumRefCountedObjects(); 134 } 135 136 RefCounted(const RefCounted &) = delete; 137 RefCounted &operator=(const RefCounted &) = delete; 138 139 void addRef(int64_t count = 1) { refCount.fetch_add(count); } 140 141 void dropRef(int64_t count = 1) { 142 int64_t previous = refCount.fetch_sub(count); 143 assert(previous >= count && "reference count should not go below zero"); 144 if (previous == count) 145 destroy(); 146 } 147 148 protected: 149 virtual void destroy() { delete this; } 150 151 private: 152 AsyncRuntime *runtime; 153 std::atomic<int64_t> refCount; 154 }; 155 156 } // namespace 157 158 // Returns the default per-process instance of an async runtime. 159 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() { 160 static auto runtime = std::make_unique<AsyncRuntime>(); 161 return runtime; 162 } 163 164 static void resetDefaultAsyncRuntime() { 165 return getDefaultAsyncRuntimeInstance().reset(); 166 } 167 168 static AsyncRuntime *getDefaultAsyncRuntime() { 169 return getDefaultAsyncRuntimeInstance().get(); 170 } 171 172 // Async token provides a mechanism to signal asynchronous operation completion. 173 struct AsyncToken : public RefCounted { 174 // AsyncToken created with a reference count of 2 because it will be returned 175 // to the `async.execute` caller and also will be later on emplaced by the 176 // asynchronously executed task. If the caller immediately will drop its 177 // reference we must ensure that the token will be alive until the 178 // asynchronous operation is completed. 179 AsyncToken(AsyncRuntime *runtime) 180 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {} 181 182 std::atomic<State::StateEnum> state; 183 184 // Pending awaiters are guarded by a mutex. 185 std::mutex mu; 186 std::condition_variable cv; 187 std::vector<std::function<void()>> awaiters; 188 }; 189 190 // Async value provides a mechanism to access the result of asynchronous 191 // operations. It owns the storage that is used to store/load the value of the 192 // underlying type, and a flag to signal if the value is ready or not. 193 struct AsyncValue : public RefCounted { 194 // AsyncValue similar to an AsyncToken created with a reference count of 2. 195 AsyncValue(AsyncRuntime *runtime, int64_t size) 196 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable), 197 storage(size) {} 198 199 std::atomic<State::StateEnum> state; 200 201 // Use vector of bytes to store async value payload. 202 std::vector<int8_t> storage; 203 204 // Pending awaiters are guarded by a mutex. 205 std::mutex mu; 206 std::condition_variable cv; 207 std::vector<std::function<void()>> awaiters; 208 }; 209 210 // Async group provides a mechanism to group together multiple async tokens or 211 // values to await on all of them together (wait for the completion of all 212 // tokens or values added to the group). 213 struct AsyncGroup : public RefCounted { 214 AsyncGroup(AsyncRuntime *runtime, int64_t size) 215 : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {} 216 217 std::atomic<int> pendingTokens; 218 std::atomic<int> numErrors; 219 std::atomic<int> rank; 220 221 // Pending awaiters are guarded by a mutex. 222 std::mutex mu; 223 std::condition_variable cv; 224 std::vector<std::function<void()>> awaiters; 225 }; 226 227 // Adds references to reference counted runtime object. 228 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) { 229 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 230 refCounted->addRef(count); 231 } 232 233 // Drops references from reference counted runtime object. 234 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) { 235 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 236 refCounted->dropRef(count); 237 } 238 239 // Creates a new `async.token` in not-ready state. 240 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 241 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime()); 242 return token; 243 } 244 245 // Creates a new `async.value` in not-ready state. 246 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) { 247 AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size); 248 return value; 249 } 250 251 // Create a new `async.group` in empty state. 252 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) { 253 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size); 254 return group; 255 } 256 257 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, 258 AsyncGroup *group) { 259 std::unique_lock<std::mutex> lockToken(token->mu); 260 std::unique_lock<std::mutex> lockGroup(group->mu); 261 262 // Get the rank of the token inside the group before we drop the reference. 263 int rank = group->rank.fetch_add(1); 264 265 auto onTokenReady = [group, token]() { 266 // Increment the number of errors in the group. 267 if (State(token->state).isError()) 268 group->numErrors.fetch_add(1); 269 270 // If pending tokens go below zero it means that more tokens than the group 271 // size were added to this group. 272 assert(group->pendingTokens > 0 && "wrong group size"); 273 274 // Run all group awaiters if it was the last token in the group. 275 if (group->pendingTokens.fetch_sub(1) == 1) { 276 group->cv.notify_all(); 277 for (auto &awaiter : group->awaiters) 278 awaiter(); 279 } 280 }; 281 282 if (State(token->state).isAvailableOrError()) { 283 // Update group pending tokens immediately and maybe run awaiters. 284 onTokenReady(); 285 286 } else { 287 // Update group pending tokens when token will become ready. Because this 288 // will happen asynchronously we must ensure that `group` is alive until 289 // then, and re-ackquire the lock. 290 group->addRef(); 291 292 token->awaiters.emplace_back([group, onTokenReady]() { 293 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 294 { 295 std::unique_lock<std::mutex> lockGroup(group->mu); 296 onTokenReady(); 297 } 298 group->dropRef(); 299 }); 300 } 301 302 return rank; 303 } 304 305 // Switches `async.token` to available or error state (terminatl state) and runs 306 // all awaiters. 307 static void setTokenState(AsyncToken *token, State state) { 308 assert(state.isAvailableOrError() && "must be terminal state"); 309 assert(State(token->state).isUnavailable() && "token must be unavailable"); 310 311 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 312 { 313 std::unique_lock<std::mutex> lock(token->mu); 314 token->state = state; 315 token->cv.notify_all(); 316 for (auto &awaiter : token->awaiters) 317 awaiter(); 318 } 319 320 // Async tokens created with a ref count `2` to keep token alive until the 321 // async task completes. Drop this reference explicitly when token emplaced. 322 token->dropRef(); 323 } 324 325 static void setValueState(AsyncValue *value, State state) { 326 assert(state.isAvailableOrError() && "must be terminal state"); 327 assert(State(value->state).isUnavailable() && "value must be unavailable"); 328 329 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 330 { 331 std::unique_lock<std::mutex> lock(value->mu); 332 value->state = state; 333 value->cv.notify_all(); 334 for (auto &awaiter : value->awaiters) 335 awaiter(); 336 } 337 338 // Async values created with a ref count `2` to keep value alive until the 339 // async task completes. Drop this reference explicitly when value emplaced. 340 value->dropRef(); 341 } 342 343 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 344 setTokenState(token, State::kAvailable); 345 } 346 347 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { 348 setValueState(value, State::kAvailable); 349 } 350 351 extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) { 352 setTokenState(token, State::kError); 353 } 354 355 extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) { 356 setValueState(value, State::kError); 357 } 358 359 extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) { 360 return State(token->state).isError(); 361 } 362 363 extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) { 364 return State(value->state).isError(); 365 } 366 367 extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) { 368 return group->numErrors.load() > 0; 369 } 370 371 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 372 std::unique_lock<std::mutex> lock(token->mu); 373 if (!State(token->state).isAvailableOrError()) 374 token->cv.wait( 375 lock, [token] { return State(token->state).isAvailableOrError(); }); 376 } 377 378 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { 379 std::unique_lock<std::mutex> lock(value->mu); 380 if (!State(value->state).isAvailableOrError()) 381 value->cv.wait( 382 lock, [value] { return State(value->state).isAvailableOrError(); }); 383 } 384 385 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 386 std::unique_lock<std::mutex> lock(group->mu); 387 if (group->pendingTokens != 0) 388 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 389 } 390 391 // Returns a pointer to the storage owned by the async value. 392 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) { 393 assert(!State(value->state).isError() && "unexpected error state"); 394 return value->storage.data(); 395 } 396 397 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 398 auto *runtime = getDefaultAsyncRuntime(); 399 runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); }); 400 } 401 402 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 403 CoroHandle handle, 404 CoroResume resume) { 405 auto execute = [handle, resume]() { (*resume)(handle); }; 406 std::unique_lock<std::mutex> lock(token->mu); 407 if (State(token->state).isAvailableOrError()) { 408 lock.unlock(); 409 execute(); 410 } else { 411 token->awaiters.emplace_back([execute]() { execute(); }); 412 } 413 } 414 415 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, 416 CoroHandle handle, 417 CoroResume resume) { 418 auto execute = [handle, resume]() { (*resume)(handle); }; 419 std::unique_lock<std::mutex> lock(value->mu); 420 if (State(value->state).isAvailableOrError()) { 421 lock.unlock(); 422 execute(); 423 } else { 424 value->awaiters.emplace_back([execute]() { execute(); }); 425 } 426 } 427 428 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, 429 CoroHandle handle, 430 CoroResume resume) { 431 auto execute = [handle, resume]() { (*resume)(handle); }; 432 std::unique_lock<std::mutex> lock(group->mu); 433 if (group->pendingTokens == 0) { 434 lock.unlock(); 435 execute(); 436 } else { 437 group->awaiters.emplace_back([execute]() { execute(); }); 438 } 439 } 440 441 extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() { 442 return getDefaultAsyncRuntime()->getThreadPool().getThreadCount(); 443 } 444 445 //===----------------------------------------------------------------------===// 446 // Small async runtime support library for testing. 447 //===----------------------------------------------------------------------===// 448 449 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 450 static thread_local std::thread::id thisId = std::this_thread::get_id(); 451 std::cout << "Current thread id: " << thisId << std::endl; 452 } 453 454 //===----------------------------------------------------------------------===// 455 // MLIR Runner (JitRunner) dynamic library integration. 456 //===----------------------------------------------------------------------===// 457 458 // Export symbols for the MLIR runner integration. All other symbols are hidden. 459 #ifdef _WIN32 460 #define API __declspec(dllexport) 461 #else 462 #define API __attribute__((visibility("default"))) 463 #endif 464 465 // Visual Studio had a bug that fails to compile nested generic lambdas 466 // inside an `extern "C"` function. 467 // https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html 468 // The bug is fixed in VS2019 16.1. Separating the declaration and definition is 469 // a work around for older versions of Visual Studio. 470 // NOLINTNEXTLINE(*-identifier-naming): externally called. 471 extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols); 472 473 // NOLINTNEXTLINE(*-identifier-naming): externally called. 474 void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) { 475 auto exportSymbol = [&](llvm::StringRef name, auto ptr) { 476 assert(exportSymbols.count(name) == 0 && "symbol already exists"); 477 exportSymbols[name] = reinterpret_cast<void *>(ptr); 478 }; 479 480 exportSymbol("mlirAsyncRuntimeAddRef", 481 &mlir::runtime::mlirAsyncRuntimeAddRef); 482 exportSymbol("mlirAsyncRuntimeDropRef", 483 &mlir::runtime::mlirAsyncRuntimeDropRef); 484 exportSymbol("mlirAsyncRuntimeExecute", 485 &mlir::runtime::mlirAsyncRuntimeExecute); 486 exportSymbol("mlirAsyncRuntimeGetValueStorage", 487 &mlir::runtime::mlirAsyncRuntimeGetValueStorage); 488 exportSymbol("mlirAsyncRuntimeCreateToken", 489 &mlir::runtime::mlirAsyncRuntimeCreateToken); 490 exportSymbol("mlirAsyncRuntimeCreateValue", 491 &mlir::runtime::mlirAsyncRuntimeCreateValue); 492 exportSymbol("mlirAsyncRuntimeEmplaceToken", 493 &mlir::runtime::mlirAsyncRuntimeEmplaceToken); 494 exportSymbol("mlirAsyncRuntimeEmplaceValue", 495 &mlir::runtime::mlirAsyncRuntimeEmplaceValue); 496 exportSymbol("mlirAsyncRuntimeSetTokenError", 497 &mlir::runtime::mlirAsyncRuntimeSetTokenError); 498 exportSymbol("mlirAsyncRuntimeSetValueError", 499 &mlir::runtime::mlirAsyncRuntimeSetValueError); 500 exportSymbol("mlirAsyncRuntimeIsTokenError", 501 &mlir::runtime::mlirAsyncRuntimeIsTokenError); 502 exportSymbol("mlirAsyncRuntimeIsValueError", 503 &mlir::runtime::mlirAsyncRuntimeIsValueError); 504 exportSymbol("mlirAsyncRuntimeIsGroupError", 505 &mlir::runtime::mlirAsyncRuntimeIsGroupError); 506 exportSymbol("mlirAsyncRuntimeAwaitToken", 507 &mlir::runtime::mlirAsyncRuntimeAwaitToken); 508 exportSymbol("mlirAsyncRuntimeAwaitValue", 509 &mlir::runtime::mlirAsyncRuntimeAwaitValue); 510 exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute", 511 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute); 512 exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute", 513 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute); 514 exportSymbol("mlirAsyncRuntimeCreateGroup", 515 &mlir::runtime::mlirAsyncRuntimeCreateGroup); 516 exportSymbol("mlirAsyncRuntimeAddTokenToGroup", 517 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup); 518 exportSymbol("mlirAsyncRuntimeAwaitAllInGroup", 519 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup); 520 exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute", 521 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute); 522 exportSymbol("mlirAsyncRuntimGetNumWorkerThreads", 523 &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads); 524 exportSymbol("mlirAsyncRuntimePrintCurrentThreadId", 525 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId); 526 } 527 528 // NOLINTNEXTLINE(*-identifier-naming): externally called. 529 extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); } 530 531 } // namespace runtime 532 } // namespace mlir 533 534 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 535