//===- AsyncRuntime.cpp - Async runtime reference implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the basic Async runtime API that supports lowering from
// the Async dialect to the LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward declare class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations; the Async runtime API is
// built on top of a default per-process runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  ~AsyncRuntime() {
    threadPool.wait(); // wait for the completion of all async tasks
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  int32_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

  llvm::ThreadPool &getThreadPool() { return threadPool; }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int32_t> numRefCountedObjects;
  llvm::ThreadPool threadPool;
};

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  void addRef(int32_t count = 1) { refCount.fetch_add(count); }

  void dropRef(int32_t count = 1) {
    int32_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int32_t> refCount;
};

} // namespace

// Returns the default per-process instance of an async runtime.
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime;
}

static void resetDefaultAsyncRuntime() {
  getDefaultAsyncRuntimeInstance().reset();
}

static AsyncRuntime *getDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().get();
}

// Async token provides a mechanism to signal asynchronous operation completion.
struct AsyncToken : public RefCounted {
  // AsyncToken is created with a reference count of 2 because it is returned
  // to the `async.execute` caller and will later be emplaced by the
  // asynchronously executed task. If the caller drops its reference
  // immediately, we must still ensure that the token stays alive until the
  // asynchronous operation is completed (see the illustrative call sequence
  // sketched after the AsyncValue definition below).
  AsyncToken(AsyncRuntime *runtime)
      : RefCounted(runtime, /*count=*/2), ready(false) {}

  std::atomic<bool> ready;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};

// Async value provides a mechanism to access the result of asynchronous
// operations. It owns the storage that is used to store/load the value of the
// underlying type, and a flag to signal if the value is ready or not.
struct AsyncValue : public RefCounted {
  // Similar to an AsyncToken, an AsyncValue is created with a reference count
  // of 2.
  AsyncValue(AsyncRuntime *runtime, int32_t size)
      : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {}

  std::atomic<bool> ready;

  // Use a vector of bytes to store the async value payload.
  std::vector<int8_t> storage;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
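
// The reference count of 2 used by AsyncToken and AsyncValue encodes a simple
// ownership protocol: one reference belongs to the caller that receives the
// object, the other to the task that eventually emplaces it. The sketch below
// is illustrative only (it is not code produced by the Async-to-LLVM lowering,
// and it runs everything on the calling thread); it shows the call sequence
// this API expects and how the two references are released:
//
//   AsyncToken *token = mlirAsyncRuntimeCreateToken(); // refCount == 2
//
//   // Producer side: signal completion. Emplacing drops the task's reference.
//   mlirAsyncRuntimeEmplaceToken(token);               // refCount == 1
//
//   // Consumer side: wait for readiness, then drop the caller's reference.
//   mlirAsyncRuntimeAwaitToken(token);
//   mlirAsyncRuntimeDropRef(token, /*count=*/1);       // refCount == 0, freed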

// Async group provides a mechanism to group together multiple async tokens or
// values to await on all of them together (wait for the completion of all
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
  AsyncGroup(AsyncRuntime *runtime)
      : RefCounted(runtime), pendingTokens(0), rank(0) {}

  std::atomic<int> pendingTokens;
  std::atomic<int> rank;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};

// Adds references to a reference counted runtime object.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->addRef(count);
}

// Drops references from a reference counted runtime object.
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->dropRef(count);
}

// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
  return token;
}

// Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
  return value;
}

// Creates a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
  return group;
}

// Adds the token to the group and returns the token's rank (the position at
// which it was added to the group).
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
                                                   AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  // Get the rank of the token inside the group before we drop the reference.
  int rank = group->rank.fetch_add(1);
  group->pendingTokens.fetch_add(1);

  auto onTokenReady = [group]() {
    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (token->ready) {
    // Update group pending tokens immediately and maybe run awaiters.
    onTokenReady();

  } else {
    // Update group pending tokens when the token becomes ready. Because this
    // will happen asynchronously we must ensure that `group` is alive until
    // then, and re-acquire the lock.
    group->addRef();

    token->awaiters.push_back([group, onTokenReady]() {
      // Make sure that `dropRef` does not destroy the mutex owned by the lock.
      {
        std::unique_lock<std::mutex> lockGroup(group->mu);
        onTokenReady();
      }
      group->dropRef();
    });
  }

  return rank;
}

// Switches `async.token` to ready state and runs all awaiters.
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(token->mu);
    token->ready = true;
    token->cv.notify_all();
    for (auto &awaiter : token->awaiters)
      awaiter();
  }

  // Async tokens are created with a reference count of 2 to keep the token
  // alive until the async task completes. Drop this reference explicitly when
  // the token is emplaced.
  token->dropRef();
}
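
// A group simply counts pending tokens and runs its awaiters once the counter
// reaches zero. The sketch below is an illustrative, single-threaded use of
// the group API (in lowered code the tokens would be emplaced by tasks running
// on the runtime's thread pool); it also shows where the caller-owned
// references are released:
//
//   AsyncGroup *group = mlirAsyncRuntimeCreateGroup();
//   AsyncToken *t0 = mlirAsyncRuntimeCreateToken();
//   AsyncToken *t1 = mlirAsyncRuntimeCreateToken();
//
//   mlirAsyncRuntimeAddTokenToGroup(t0, group); // rank 0, pendingTokens == 1
//   mlirAsyncRuntimeAddTokenToGroup(t1, group); // rank 1, pendingTokens == 2
//
//   mlirAsyncRuntimeEmplaceToken(t0);           // pendingTokens == 1
//   mlirAsyncRuntimeEmplaceToken(t1);           // pendingTokens == 0
//
//   mlirAsyncRuntimeAwaitAllInGroup(group);     // returns immediately here
//
//   mlirAsyncRuntimeDropRef(t0, /*count=*/1);
//   mlirAsyncRuntimeDropRef(t1, /*count=*/1);
//   mlirAsyncRuntimeDropRef(group, /*count=*/1);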

// Switches `async.value` to ready state and runs all awaiters.
extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(value->mu);
    value->ready = true;
    value->cv.notify_all();
    for (auto &awaiter : value->awaiters)
      awaiter();
  }

  // Async values are created with a reference count of 2 to keep the value
  // alive until the async task completes. Drop this reference explicitly when
  // the value is emplaced.
  value->dropRef();
}

// Blocks the caller until the token becomes ready.
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> lock(token->mu);
  if (!token->ready)
    token->cv.wait(lock, [token] { return token->ready.load(); });
}

// Blocks the caller until the value becomes ready.
extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
  std::unique_lock<std::mutex> lock(value->mu);
  if (!value->ready)
    value->cv.wait(lock, [value] { return value->ready.load(); });
}

// Blocks the caller until all tokens added to the group become ready.
extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens != 0)
    group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}

// Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
  return value->storage.data();
}

// Submits the coroutine resumption to the runtime-managed thread pool.
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto *runtime = getDefaultAsyncRuntime();
  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}

// Resumes the coroutine immediately if the token is ready, otherwise defers
// the resumption until the token is emplaced.
extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(token->mu);
  if (token->ready) {
    lock.unlock();
    execute();
  } else {
    token->awaiters.push_back([execute]() { execute(); });
  }
}

// Resumes the coroutine immediately if the value is ready, otherwise defers
// the resumption until the value is emplaced.
extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(value->mu);
  if (value->ready) {
    lock.unlock();
    execute();
  } else {
    value->awaiters.push_back([execute]() { execute(); });
  }
}

// Resumes the coroutine immediately if all tokens in the group are ready,
// otherwise defers the resumption until all pending tokens are emplaced.
extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                          CoroHandle handle,
                                                          CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens == 0) {
    lock.unlock();
    execute();
  } else {
    group->awaiters.push_back([execute]() { execute(); });
  }
}

//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id thisId = std::this_thread::get_id();
  std::cout << "Current thread id: " << thisId << std::endl;
}
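
// The mlirAsyncRuntime*AndExecute functions above implement the non-blocking
// await path used by lowered coroutines: the caller hands the runtime a
// coroutine handle and a resume function, and the runtime either resumes right
// away (if the awaited object is already ready) or stores the resumption as an
// awaiter. The sketch below is illustrative only; `resumeFn` is a made-up
// stand-in for the coroutine resumption that generated code would use, and it
// assumes the CoroHandle/CoroResume typedefs declared in AsyncRuntime.h:
//
//   static void resumeFn(CoroHandle handle) {
//     // Generated code would resume the suspended coroutine `handle` here.
//   }
//
//   AsyncToken *token = mlirAsyncRuntimeCreateToken();
//   mlirAsyncRuntimeAwaitTokenAndExecute(token, /*handle=*/nullptr, resumeFn);
//   mlirAsyncRuntimeEmplaceToken(token); // runs the stored awaiter -> resumeFn
//   mlirAsyncRuntimeDropRef(token, /*count=*/1);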

//===----------------------------------------------------------------------===//
// MLIR Runner (JitRunner) dynamic library integration.
//===----------------------------------------------------------------------===//

// Export symbols for the MLIR runner integration. All other symbols are hidden.
#ifdef _WIN32
#define API __declspec(dllexport)
#else
#define API __attribute__((visibility("default")))
#endif

// Older versions of Visual Studio have a bug that makes them fail to compile
// nested generic lambdas inside an `extern "C"` function.
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a workaround for older versions of Visual Studio.
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);

void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
    assert(exportSymbols.count(name) == 0 && "symbol already exists");
    exportSymbols[name] = reinterpret_cast<void *>(ptr);
  };

  exportSymbol("mlirAsyncRuntimeAddRef",
               &mlir::runtime::mlirAsyncRuntimeAddRef);
  exportSymbol("mlirAsyncRuntimeDropRef",
               &mlir::runtime::mlirAsyncRuntimeDropRef);
  exportSymbol("mlirAsyncRuntimeExecute",
               &mlir::runtime::mlirAsyncRuntimeExecute);
  exportSymbol("mlirAsyncRuntimeGetValueStorage",
               &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
  exportSymbol("mlirAsyncRuntimeCreateToken",
               &mlir::runtime::mlirAsyncRuntimeCreateToken);
  exportSymbol("mlirAsyncRuntimeCreateValue",
               &mlir::runtime::mlirAsyncRuntimeCreateValue);
  exportSymbol("mlirAsyncRuntimeEmplaceToken",
               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
  exportSymbol("mlirAsyncRuntimeEmplaceValue",
               &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
  exportSymbol("mlirAsyncRuntimeAwaitToken",
               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
  exportSymbol("mlirAsyncRuntimeAwaitValue",
               &mlir::runtime::mlirAsyncRuntimeAwaitValue);
  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
  exportSymbol("mlirAsyncRuntimeCreateGroup",
               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}

extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }

} // namespace runtime
} // namespace mlir

#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
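
// How a host like the MLIR JitRunner is expected to use the integration above
// (an illustrative, simplified sketch, not the runner's actual code): it loads
// this library as a shared object, calls `__mlir_runner_init` to collect the
// exported symbols, registers every (name, address) pair with its execution
// engine, and calls `__mlir_runner_destroy` on shutdown:
//
//   llvm::StringMap<void *> exportSymbols;
//   __mlir_runner_init(exportSymbols);
//   // ... register each symbol with the JIT'd module ...
//   __mlir_runner_destroy();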