//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic Async runtime API for supporting Async dialect
// to LLVM dialect lowering.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward declare class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations and the Async runtime API is
// built on top of the default runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  ~AsyncRuntime() {
    threadPool.wait(); // wait for the completion of all async tasks
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  int32_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

  llvm::ThreadPool &getThreadPool() { return threadPool; }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int32_t> numRefCountedObjects;
  llvm::ThreadPool threadPool;
};

// -------------------------------------------------------------------------- //
// A state of the async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
  }

private:
  StateEnum state;
};

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  void addRef(int32_t count = 1) { refCount.fetch_add(count); }

  void dropRef(int32_t count = 1) {
    int32_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int32_t> refCount;
};

} // namespace

// Returns the default per-process instance of an async runtime.
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime;
}

static void resetDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().reset();
}

static AsyncRuntime *getDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().get();
}

// Async token provides a mechanism to signal asynchronous operation
// completion.
struct AsyncToken : public RefCounted {
  // AsyncToken is created with a reference count of 2 because it will be
  // returned to the `async.execute` caller and will also later be emplaced by
  // the asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the token stays alive until the
  // asynchronous operation is completed.
  AsyncToken(AsyncRuntime *runtime)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}

  std::atomic<State::StateEnum> state;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};

// Async value provides a mechanism to access the result of asynchronous
// operations. It owns the storage that is used to store/load the value of the
// underlying type, and a flag to signal if the value is ready or not.
struct AsyncValue : public RefCounted {
  // Like an AsyncToken, AsyncValue is created with a reference count of 2.
  AsyncValue(AsyncRuntime *runtime, int32_t size)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
        storage(size) {}

  std::atomic<State::StateEnum> state;

  // Use vector of bytes to store async value payload.
  std::vector<int8_t> storage;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};

// Async group provides a mechanism to group together multiple async tokens or
// values to await on all of them together (wait for the completion of all
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
  AsyncGroup(AsyncRuntime *runtime)
      : RefCounted(runtime), pendingTokens(0), rank(0) {}

  std::atomic<int> pendingTokens;
  std::atomic<int> rank;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};

// Adds references to reference counted runtime object.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->addRef(count);
}

// Drops references from reference counted runtime object.
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->dropRef(count);
}

// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
  return token;
}

// Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
  return value;
}

// Creates a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
  return group;
}

extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
                                                   AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  // Get the rank of the token inside the group before we drop the reference.
  int rank = group->rank.fetch_add(1);
  group->pendingTokens.fetch_add(1);

  auto onTokenReady = [group]() {
    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (State(token->state).isAvailableOrError()) {
    // Update group pending tokens immediately and maybe run awaiters.
    onTokenReady();

  } else {
    // Update group pending tokens when the token becomes ready. Because this
    // will happen asynchronously, we must ensure that `group` is alive until
    // then, and re-acquire the group lock inside the awaiter.
    group->addRef();

    token->awaiters.push_back([group, onTokenReady]() {
      // Make sure that `dropRef` does not destroy the mutex owned by the lock.
      {
        std::unique_lock<std::mutex> lockGroup(group->mu);
        onTokenReady();
      }
      group->dropRef();
    });
  }

  return rank;
}

// Switches `async.token` to available or error state (terminal state) and runs
// all awaiters.
static void setTokenState(AsyncToken *token, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(token->state).isUnavailable() && "token must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(token->mu);
    token->state = state;
    token->cv.notify_all();
    for (auto &awaiter : token->awaiters)
      awaiter();
  }

  // Async tokens are created with a reference count of 2 to keep the token
  // alive until the async task completes. Drop this reference explicitly when
  // the token is emplaced.
  token->dropRef();
}

static void setValueState(AsyncValue *value, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(value->state).isUnavailable() && "value must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(value->mu);
    value->state = state;
    value->cv.notify_all();
    for (auto &awaiter : value->awaiters)
      awaiter();
  }

  // Async values are created with a reference count of 2 to keep the value
  // alive until the async task completes. Drop this reference explicitly when
  // the value is emplaced.
  value->dropRef();
}

extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  setTokenState(token, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
  setValueState(value, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
  setTokenState(token, State::kError);
}

extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
  setValueState(value, State::kError);
}

extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
  return State(token->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
  return State(value->state).isError();
}

extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> lock(token->mu);
  if (!State(token->state).isAvailableOrError())
    token->cv.wait(
        lock, [token] { return State(token->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
  std::unique_lock<std::mutex> lock(value->mu);
  if (!State(value->state).isAvailableOrError())
    value->cv.wait(
        lock, [value] { return State(value->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens != 0)
    group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}

// Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
  assert(!State(value->state).isError() && "unexpected error state");
  return value->storage.data();
}

extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto *runtime = getDefaultAsyncRuntime();
  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}

extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(token->mu);
  if (State(token->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    token->awaiters.push_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(value->mu);
  if (State(value->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    value->awaiters.push_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                          CoroHandle handle,
                                                          CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens == 0) {
    lock.unlock();
    execute();
  } else {
    group->awaiters.push_back([execute]() { execute(); });
  }
}

//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id thisId = std::this_thread::get_id();
  std::cout << "Current thread id: " << thisId << std::endl;
}

//===----------------------------------------------------------------------===//
// MLIR Runner (JitRunner) dynamic library integration.
//===----------------------------------------------------------------------===//

// Export symbols for the MLIR runner integration. All other symbols are
// hidden.
#ifdef _WIN32
#define API __declspec(dllexport)
#else
#define API __attribute__((visibility("default")))
#endif

// Visual Studio had a bug that fails to compile nested generic lambdas
// inside an `extern "C"` function.
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a workaround for older versions of Visual Studio.
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);

void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
    assert(exportSymbols.count(name) == 0 && "symbol already exists");
    exportSymbols[name] = reinterpret_cast<void *>(ptr);
  };

  exportSymbol("mlirAsyncRuntimeAddRef",
               &mlir::runtime::mlirAsyncRuntimeAddRef);
  exportSymbol("mlirAsyncRuntimeDropRef",
               &mlir::runtime::mlirAsyncRuntimeDropRef);
  exportSymbol("mlirAsyncRuntimeExecute",
               &mlir::runtime::mlirAsyncRuntimeExecute);
  exportSymbol("mlirAsyncRuntimeGetValueStorage",
               &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
  exportSymbol("mlirAsyncRuntimeCreateToken",
               &mlir::runtime::mlirAsyncRuntimeCreateToken);
  exportSymbol("mlirAsyncRuntimeCreateValue",
               &mlir::runtime::mlirAsyncRuntimeCreateValue);
  exportSymbol("mlirAsyncRuntimeEmplaceToken",
               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
  exportSymbol("mlirAsyncRuntimeEmplaceValue",
               &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
  exportSymbol("mlirAsyncRuntimeSetTokenError",
               &mlir::runtime::mlirAsyncRuntimeSetTokenError);
  exportSymbol("mlirAsyncRuntimeSetValueError",
               &mlir::runtime::mlirAsyncRuntimeSetValueError);
  exportSymbol("mlirAsyncRuntimeIsTokenError",
               &mlir::runtime::mlirAsyncRuntimeIsTokenError);
  exportSymbol("mlirAsyncRuntimeIsValueError",
               &mlir::runtime::mlirAsyncRuntimeIsValueError);
  exportSymbol("mlirAsyncRuntimeAwaitToken",
               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
  exportSymbol("mlirAsyncRuntimeAwaitValue",
               &mlir::runtime::mlirAsyncRuntimeAwaitValue);
  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
  exportSymbol("mlirAsyncRuntimeCreateGroup",
               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}

extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }

} // namespace runtime
} // namespace mlir

#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
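
//===----------------------------------------------------------------------===//
// Example: driving the runtime API from host code (illustrative only).
//===----------------------------------------------------------------------===//
//
// A minimal sketch of how the C API above could be exercised directly, e.g.
// from a hand-written test. It is kept as a comment and is not compiled into
// the runtime; the real callers are the LLVM coroutines produced by the
// async-to-llvm lowering. The `exampleUsage` function is hypothetical, and the
// sketch assumes ValueStorage is a raw byte pointer.
//
//   void exampleUsage() {
//     using namespace mlir::runtime;
//
//     // Token and value start unavailable with a reference count of 2: one
//     // reference for the caller, one dropped when they are emplaced.
//     AsyncToken *token = mlirAsyncRuntimeCreateToken();
//     AsyncValue *value = mlirAsyncRuntimeCreateValue(sizeof(int32_t));
//
//     // The group tracks pending tokens; adding an unavailable token bumps
//     // the pending count until that token is emplaced.
//     AsyncGroup *group = mlirAsyncRuntimeCreateGroup();
//     mlirAsyncRuntimeAddTokenToGroup(token, group);
//
//     // Produce the value payload and switch both objects to available.
//     void *storage = mlirAsyncRuntimeGetValueStorage(value);
//     *static_cast<int32_t *>(storage) = 42;
//     mlirAsyncRuntimeEmplaceValue(value);
//     mlirAsyncRuntimeEmplaceToken(token);
//
//     // Blocks until every token added to the group is available (or error).
//     mlirAsyncRuntimeAwaitAllInGroup(group);
//
//     // Release the caller-owned references; each object destroys itself
//     // when its reference count reaches zero.
//     mlirAsyncRuntimeDropRef(value, 1);
//     mlirAsyncRuntimeDropRef(token, 1);
//     mlirAsyncRuntimeDropRef(group, 1);
//   }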