1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements basic Async runtime API for supporting Async dialect 10 // to LLVM dialect lowering. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/AsyncRuntime.h" 15 16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 17 18 #include <atomic> 19 #include <cassert> 20 #include <condition_variable> 21 #include <functional> 22 #include <iostream> 23 #include <mutex> 24 #include <thread> 25 #include <vector> 26 27 #include "llvm/ADT/StringMap.h" 28 #include "llvm/Support/ThreadPool.h" 29 30 using namespace mlir::runtime; 31 32 //===----------------------------------------------------------------------===// 33 // Async runtime API. 34 //===----------------------------------------------------------------------===// 35 36 namespace mlir { 37 namespace runtime { 38 namespace { 39 40 // Forward declare class defined below. 41 class RefCounted; 42 43 // -------------------------------------------------------------------------- // 44 // AsyncRuntime orchestrates all async operations and Async runtime API is built 45 // on top of the default runtime instance. 46 // -------------------------------------------------------------------------- // 47 48 class AsyncRuntime { 49 public: 50 AsyncRuntime() : numRefCountedObjects(0) {} 51 52 ~AsyncRuntime() { 53 threadPool.wait(); // wait for the completion of all async tasks 54 assert(getNumRefCountedObjects() == 0 && 55 "all ref counted objects must be destroyed"); 56 } 57 58 int32_t getNumRefCountedObjects() { 59 return numRefCountedObjects.load(std::memory_order_relaxed); 60 } 61 62 llvm::ThreadPool &getThreadPool() { return threadPool; } 63 64 private: 65 friend class RefCounted; 66 67 // Count the total number of reference counted objects in this instance 68 // of an AsyncRuntime. For debugging purposes only. 69 void addNumRefCountedObjects() { 70 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); 71 } 72 void dropNumRefCountedObjects() { 73 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); 74 } 75 76 std::atomic<int32_t> numRefCountedObjects; 77 llvm::ThreadPool threadPool; 78 }; 79 80 // -------------------------------------------------------------------------- // 81 // A state of the async runtime value (token, value or group). 82 // -------------------------------------------------------------------------- // 83 84 class State { 85 public: 86 enum StateEnum : int8_t { 87 // The underlying value is not yet available for consumption. 88 kUnavailable = 0, 89 // The underlying value is available for consumption. This state can not 90 // transition to any other state. 91 kAvailable = 1, 92 // This underlying value is available and contains an error. This state can 93 // not transition to any other state. 94 kError = 2, 95 }; 96 97 /* implicit */ State(StateEnum s) : state(s) {} 98 /* implicit */ operator StateEnum() { return state; } 99 100 bool isUnavailable() const { return state == kUnavailable; } 101 bool isAvailable() const { return state == kAvailable; } 102 bool isError() const { return state == kError; } 103 bool isAvailableOrError() const { return isAvailable() || isError(); } 104 105 const char *debug() const { 106 switch (state) { 107 case kUnavailable: 108 return "unavailable"; 109 case kAvailable: 110 return "available"; 111 case kError: 112 return "error"; 113 } 114 } 115 116 private: 117 StateEnum state; 118 }; 119 120 // -------------------------------------------------------------------------- // 121 // A base class for all reference counted objects created by the async runtime. 122 // -------------------------------------------------------------------------- // 123 124 class RefCounted { 125 public: 126 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) 127 : runtime(runtime), refCount(refCount) { 128 runtime->addNumRefCountedObjects(); 129 } 130 131 virtual ~RefCounted() { 132 assert(refCount.load() == 0 && "reference count must be zero"); 133 runtime->dropNumRefCountedObjects(); 134 } 135 136 RefCounted(const RefCounted &) = delete; 137 RefCounted &operator=(const RefCounted &) = delete; 138 139 void addRef(int32_t count = 1) { refCount.fetch_add(count); } 140 141 void dropRef(int32_t count = 1) { 142 int32_t previous = refCount.fetch_sub(count); 143 assert(previous >= count && "reference count should not go below zero"); 144 if (previous == count) 145 destroy(); 146 } 147 148 protected: 149 virtual void destroy() { delete this; } 150 151 private: 152 AsyncRuntime *runtime; 153 std::atomic<int32_t> refCount; 154 }; 155 156 } // namespace 157 158 // Returns the default per-process instance of an async runtime. 159 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() { 160 static auto runtime = std::make_unique<AsyncRuntime>(); 161 return runtime; 162 } 163 164 static void resetDefaultAsyncRuntime() { 165 return getDefaultAsyncRuntimeInstance().reset(); 166 } 167 168 static AsyncRuntime *getDefaultAsyncRuntime() { 169 return getDefaultAsyncRuntimeInstance().get(); 170 } 171 172 // Async token provides a mechanism to signal asynchronous operation completion. 173 struct AsyncToken : public RefCounted { 174 // AsyncToken created with a reference count of 2 because it will be returned 175 // to the `async.execute` caller and also will be later on emplaced by the 176 // asynchronously executed task. If the caller immediately will drop its 177 // reference we must ensure that the token will be alive until the 178 // asynchronous operation is completed. 179 AsyncToken(AsyncRuntime *runtime) 180 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {} 181 182 std::atomic<State::StateEnum> state; 183 184 // Pending awaiters are guarded by a mutex. 185 std::mutex mu; 186 std::condition_variable cv; 187 std::vector<std::function<void()>> awaiters; 188 }; 189 190 // Async value provides a mechanism to access the result of asynchronous 191 // operations. It owns the storage that is used to store/load the value of the 192 // underlying type, and a flag to signal if the value is ready or not. 193 struct AsyncValue : public RefCounted { 194 // AsyncValue similar to an AsyncToken created with a reference count of 2. 195 AsyncValue(AsyncRuntime *runtime, int32_t size) 196 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable), 197 storage(size) {} 198 199 std::atomic<State::StateEnum> state; 200 201 // Use vector of bytes to store async value payload. 202 std::vector<int8_t> storage; 203 204 // Pending awaiters are guarded by a mutex. 205 std::mutex mu; 206 std::condition_variable cv; 207 std::vector<std::function<void()>> awaiters; 208 }; 209 210 // Async group provides a mechanism to group together multiple async tokens or 211 // values to await on all of them together (wait for the completion of all 212 // tokens or values added to the group). 213 struct AsyncGroup : public RefCounted { 214 AsyncGroup(AsyncRuntime *runtime) 215 : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {} 216 217 std::atomic<int> pendingTokens; 218 std::atomic<int> numErrors; 219 std::atomic<int> rank; 220 221 // Pending awaiters are guarded by a mutex. 222 std::mutex mu; 223 std::condition_variable cv; 224 std::vector<std::function<void()>> awaiters; 225 }; 226 227 // Adds references to reference counted runtime object. 228 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { 229 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 230 refCounted->addRef(count); 231 } 232 233 // Drops references from reference counted runtime object. 234 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { 235 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 236 refCounted->dropRef(count); 237 } 238 239 // Creates a new `async.token` in not-ready state. 240 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 241 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime()); 242 return token; 243 } 244 245 // Creates a new `async.value` in not-ready state. 246 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) { 247 AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size); 248 return value; 249 } 250 251 // Create a new `async.group` in empty state. 252 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { 253 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime()); 254 return group; 255 } 256 257 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, 258 AsyncGroup *group) { 259 std::unique_lock<std::mutex> lockToken(token->mu); 260 std::unique_lock<std::mutex> lockGroup(group->mu); 261 262 // Get the rank of the token inside the group before we drop the reference. 263 int rank = group->rank.fetch_add(1); 264 group->pendingTokens.fetch_add(1); 265 266 auto onTokenReady = [group, token]() { 267 // Increment the number of errors in the group. 268 if (State(token->state).isError()) 269 group->numErrors.fetch_add(1); 270 271 // Run all group awaiters if it was the last token in the group. 272 if (group->pendingTokens.fetch_sub(1) == 1) { 273 group->cv.notify_all(); 274 for (auto &awaiter : group->awaiters) 275 awaiter(); 276 } 277 }; 278 279 if (State(token->state).isAvailableOrError()) { 280 // Update group pending tokens immediately and maybe run awaiters. 281 onTokenReady(); 282 283 } else { 284 // Update group pending tokens when token will become ready. Because this 285 // will happen asynchronously we must ensure that `group` is alive until 286 // then, and re-ackquire the lock. 287 group->addRef(); 288 289 token->awaiters.push_back([group, onTokenReady]() { 290 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 291 { 292 std::unique_lock<std::mutex> lockGroup(group->mu); 293 onTokenReady(); 294 } 295 group->dropRef(); 296 }); 297 } 298 299 return rank; 300 } 301 302 // Switches `async.token` to available or error state (terminatl state) and runs 303 // all awaiters. 304 static void setTokenState(AsyncToken *token, State state) { 305 assert(state.isAvailableOrError() && "must be terminal state"); 306 assert(State(token->state).isUnavailable() && "token must be unavailable"); 307 308 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 309 { 310 std::unique_lock<std::mutex> lock(token->mu); 311 token->state = state; 312 token->cv.notify_all(); 313 for (auto &awaiter : token->awaiters) 314 awaiter(); 315 } 316 317 // Async tokens created with a ref count `2` to keep token alive until the 318 // async task completes. Drop this reference explicitly when token emplaced. 319 token->dropRef(); 320 } 321 322 static void setValueState(AsyncValue *value, State state) { 323 assert(state.isAvailableOrError() && "must be terminal state"); 324 assert(State(value->state).isUnavailable() && "value must be unavailable"); 325 326 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 327 { 328 std::unique_lock<std::mutex> lock(value->mu); 329 value->state = state; 330 value->cv.notify_all(); 331 for (auto &awaiter : value->awaiters) 332 awaiter(); 333 } 334 335 // Async values created with a ref count `2` to keep value alive until the 336 // async task completes. Drop this reference explicitly when value emplaced. 337 value->dropRef(); 338 } 339 340 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 341 setTokenState(token, State::kAvailable); 342 } 343 344 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { 345 setValueState(value, State::kAvailable); 346 } 347 348 extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) { 349 setTokenState(token, State::kError); 350 } 351 352 extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) { 353 setValueState(value, State::kError); 354 } 355 356 extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) { 357 return State(token->state).isError(); 358 } 359 360 extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) { 361 return State(value->state).isError(); 362 } 363 364 extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) { 365 return group->numErrors.load() > 0; 366 } 367 368 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 369 std::unique_lock<std::mutex> lock(token->mu); 370 if (!State(token->state).isAvailableOrError()) 371 token->cv.wait( 372 lock, [token] { return State(token->state).isAvailableOrError(); }); 373 } 374 375 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { 376 std::unique_lock<std::mutex> lock(value->mu); 377 if (!State(value->state).isAvailableOrError()) 378 value->cv.wait( 379 lock, [value] { return State(value->state).isAvailableOrError(); }); 380 } 381 382 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 383 std::unique_lock<std::mutex> lock(group->mu); 384 if (group->pendingTokens != 0) 385 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 386 } 387 388 // Returns a pointer to the storage owned by the async value. 389 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) { 390 assert(!State(value->state).isError() && "unexpected error state"); 391 return value->storage.data(); 392 } 393 394 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 395 auto *runtime = getDefaultAsyncRuntime(); 396 runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); }); 397 } 398 399 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 400 CoroHandle handle, 401 CoroResume resume) { 402 auto execute = [handle, resume]() { (*resume)(handle); }; 403 std::unique_lock<std::mutex> lock(token->mu); 404 if (State(token->state).isAvailableOrError()) { 405 lock.unlock(); 406 execute(); 407 } else { 408 token->awaiters.push_back([execute]() { execute(); }); 409 } 410 } 411 412 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, 413 CoroHandle handle, 414 CoroResume resume) { 415 auto execute = [handle, resume]() { (*resume)(handle); }; 416 std::unique_lock<std::mutex> lock(value->mu); 417 if (State(value->state).isAvailableOrError()) { 418 lock.unlock(); 419 execute(); 420 } else { 421 value->awaiters.push_back([execute]() { execute(); }); 422 } 423 } 424 425 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, 426 CoroHandle handle, 427 CoroResume resume) { 428 auto execute = [handle, resume]() { (*resume)(handle); }; 429 std::unique_lock<std::mutex> lock(group->mu); 430 if (group->pendingTokens == 0) { 431 lock.unlock(); 432 execute(); 433 } else { 434 group->awaiters.push_back([execute]() { execute(); }); 435 } 436 } 437 438 //===----------------------------------------------------------------------===// 439 // Small async runtime support library for testing. 440 //===----------------------------------------------------------------------===// 441 442 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 443 static thread_local std::thread::id thisId = std::this_thread::get_id(); 444 std::cout << "Current thread id: " << thisId << std::endl; 445 } 446 447 //===----------------------------------------------------------------------===// 448 // MLIR Runner (JitRunner) dynamic library integration. 449 //===----------------------------------------------------------------------===// 450 451 // Export symbols for the MLIR runner integration. All other symbols are hidden. 452 #ifdef _WIN32 453 #define API __declspec(dllexport) 454 #else 455 #define API __attribute__((visibility("default"))) 456 #endif 457 458 // Visual Studio had a bug that fails to compile nested generic lambdas 459 // inside an `extern "C"` function. 460 // https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html 461 // The bug is fixed in VS2019 16.1. Separating the declaration and definition is 462 // a work around for older versions of Visual Studio. 463 extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols); 464 465 void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) { 466 auto exportSymbol = [&](llvm::StringRef name, auto ptr) { 467 assert(exportSymbols.count(name) == 0 && "symbol already exists"); 468 exportSymbols[name] = reinterpret_cast<void *>(ptr); 469 }; 470 471 exportSymbol("mlirAsyncRuntimeAddRef", 472 &mlir::runtime::mlirAsyncRuntimeAddRef); 473 exportSymbol("mlirAsyncRuntimeDropRef", 474 &mlir::runtime::mlirAsyncRuntimeDropRef); 475 exportSymbol("mlirAsyncRuntimeExecute", 476 &mlir::runtime::mlirAsyncRuntimeExecute); 477 exportSymbol("mlirAsyncRuntimeGetValueStorage", 478 &mlir::runtime::mlirAsyncRuntimeGetValueStorage); 479 exportSymbol("mlirAsyncRuntimeCreateToken", 480 &mlir::runtime::mlirAsyncRuntimeCreateToken); 481 exportSymbol("mlirAsyncRuntimeCreateValue", 482 &mlir::runtime::mlirAsyncRuntimeCreateValue); 483 exportSymbol("mlirAsyncRuntimeEmplaceToken", 484 &mlir::runtime::mlirAsyncRuntimeEmplaceToken); 485 exportSymbol("mlirAsyncRuntimeEmplaceValue", 486 &mlir::runtime::mlirAsyncRuntimeEmplaceValue); 487 exportSymbol("mlirAsyncRuntimeSetTokenError", 488 &mlir::runtime::mlirAsyncRuntimeSetTokenError); 489 exportSymbol("mlirAsyncRuntimeSetValueError", 490 &mlir::runtime::mlirAsyncRuntimeSetValueError); 491 exportSymbol("mlirAsyncRuntimeIsTokenError", 492 &mlir::runtime::mlirAsyncRuntimeIsTokenError); 493 exportSymbol("mlirAsyncRuntimeIsValueError", 494 &mlir::runtime::mlirAsyncRuntimeIsValueError); 495 exportSymbol("mlirAsyncRuntimeIsGroupError", 496 &mlir::runtime::mlirAsyncRuntimeIsGroupError); 497 exportSymbol("mlirAsyncRuntimeAwaitToken", 498 &mlir::runtime::mlirAsyncRuntimeAwaitToken); 499 exportSymbol("mlirAsyncRuntimeAwaitValue", 500 &mlir::runtime::mlirAsyncRuntimeAwaitValue); 501 exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute", 502 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute); 503 exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute", 504 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute); 505 exportSymbol("mlirAsyncRuntimeCreateGroup", 506 &mlir::runtime::mlirAsyncRuntimeCreateGroup); 507 exportSymbol("mlirAsyncRuntimeAddTokenToGroup", 508 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup); 509 exportSymbol("mlirAsyncRuntimeAwaitAllInGroup", 510 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup); 511 exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute", 512 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute); 513 exportSymbol("mlirAsyncRuntimePrintCurrentThreadId", 514 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId); 515 } 516 517 extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); } 518 519 } // namespace runtime 520 } // namespace mlir 521 522 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 523