//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the basic Async runtime API that supports lowering the
// Async dialect to the LLVM dialect.
//
// Everything after the AsyncRuntime.h include is guarded by
// MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS; without that macro this translation unit
// defines nothing.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward-declared so that AsyncRuntime can name it as a friend before the
// class is defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations; the Async runtime API below
// is built on top of the default runtime instance.
46 // -------------------------------------------------------------------------- // 47 48 class AsyncRuntime { 49 public: 50 AsyncRuntime() : numRefCountedObjects(0) {} 51 52 ~AsyncRuntime() { 53 threadPool.wait(); // wait for the completion of all async tasks 54 assert(getNumRefCountedObjects() == 0 && 55 "all ref counted objects must be destroyed"); 56 } 57 58 int32_t getNumRefCountedObjects() { 59 return numRefCountedObjects.load(std::memory_order_relaxed); 60 } 61 62 llvm::ThreadPool &getThreadPool() { return threadPool; } 63 64 private: 65 friend class RefCounted; 66 67 // Count the total number of reference counted objects in this instance 68 // of an AsyncRuntime. For debugging purposes only. 69 void addNumRefCountedObjects() { 70 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); 71 } 72 void dropNumRefCountedObjects() { 73 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); 74 } 75 76 std::atomic<int32_t> numRefCountedObjects; 77 llvm::ThreadPool threadPool; 78 }; 79 80 // -------------------------------------------------------------------------- // 81 // A base class for all reference counted objects created by the async runtime. 
82 // -------------------------------------------------------------------------- // 83 84 class RefCounted { 85 public: 86 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) 87 : runtime(runtime), refCount(refCount) { 88 runtime->addNumRefCountedObjects(); 89 } 90 91 virtual ~RefCounted() { 92 assert(refCount.load() == 0 && "reference count must be zero"); 93 runtime->dropNumRefCountedObjects(); 94 } 95 96 RefCounted(const RefCounted &) = delete; 97 RefCounted &operator=(const RefCounted &) = delete; 98 99 void addRef(int32_t count = 1) { refCount.fetch_add(count); } 100 101 void dropRef(int32_t count = 1) { 102 int32_t previous = refCount.fetch_sub(count); 103 assert(previous >= count && "reference count should not go below zero"); 104 if (previous == count) 105 destroy(); 106 } 107 108 protected: 109 virtual void destroy() { delete this; } 110 111 private: 112 AsyncRuntime *runtime; 113 std::atomic<int32_t> refCount; 114 }; 115 116 } // namespace 117 118 // Returns the default per-process instance of an async runtime. 119 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() { 120 static auto runtime = std::make_unique<AsyncRuntime>(); 121 return runtime; 122 } 123 124 static void resetDefaultAsyncRuntime() { 125 return getDefaultAsyncRuntimeInstance().reset(); 126 } 127 128 static AsyncRuntime *getDefaultAsyncRuntime() { 129 return getDefaultAsyncRuntimeInstance().get(); 130 } 131 132 // Async token provides a mechanism to signal asynchronous operation completion. 133 struct AsyncToken : public RefCounted { 134 // AsyncToken created with a reference count of 2 because it will be returned 135 // to the `async.execute` caller and also will be later on emplaced by the 136 // asynchronously executed task. If the caller immediately will drop its 137 // reference we must ensure that the token will be alive until the 138 // asynchronous operation is completed. 
139 AsyncToken(AsyncRuntime *runtime) 140 : RefCounted(runtime, /*count=*/2), ready(false) {} 141 142 std::atomic<bool> ready; 143 144 // Pending awaiters are guarded by a mutex. 145 std::mutex mu; 146 std::condition_variable cv; 147 std::vector<std::function<void()>> awaiters; 148 }; 149 150 // Async value provides a mechanism to access the result of asynchronous 151 // operations. It owns the storage that is used to store/load the value of the 152 // underlying type, and a flag to signal if the value is ready or not. 153 struct AsyncValue : public RefCounted { 154 // AsyncValue similar to an AsyncToken created with a reference count of 2. 155 AsyncValue(AsyncRuntime *runtime, int32_t size) 156 : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {} 157 158 std::atomic<bool> ready; 159 160 // Use vector of bytes to store async value payload. 161 std::vector<int8_t> storage; 162 163 // Pending awaiters are guarded by a mutex. 164 std::mutex mu; 165 std::condition_variable cv; 166 std::vector<std::function<void()>> awaiters; 167 }; 168 169 // Async group provides a mechanism to group together multiple async tokens or 170 // values to await on all of them together (wait for the completion of all 171 // tokens or values added to the group). 172 struct AsyncGroup : public RefCounted { 173 AsyncGroup(AsyncRuntime *runtime) 174 : RefCounted(runtime), pendingTokens(0), rank(0) {} 175 176 std::atomic<int> pendingTokens; 177 std::atomic<int> rank; 178 179 // Pending awaiters are guarded by a mutex. 180 std::mutex mu; 181 std::condition_variable cv; 182 std::vector<std::function<void()>> awaiters; 183 }; 184 185 } // namespace runtime 186 } // namespace mlir 187 188 // Adds references to reference counted runtime object. 189 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { 190 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 191 refCounted->addRef(count); 192 } 193 194 // Drops references from reference counted runtime object. 
195 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { 196 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 197 refCounted->dropRef(count); 198 } 199 200 // Creates a new `async.token` in not-ready state. 201 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 202 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime()); 203 return token; 204 } 205 206 // Creates a new `async.value` in not-ready state. 207 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) { 208 AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size); 209 return value; 210 } 211 212 // Create a new `async.group` in empty state. 213 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { 214 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime()); 215 return group; 216 } 217 218 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, 219 AsyncGroup *group) { 220 std::unique_lock<std::mutex> lockToken(token->mu); 221 std::unique_lock<std::mutex> lockGroup(group->mu); 222 223 // Get the rank of the token inside the group before we drop the reference. 224 int rank = group->rank.fetch_add(1); 225 group->pendingTokens.fetch_add(1); 226 227 auto onTokenReady = [group]() { 228 // Run all group awaiters if it was the last token in the group. 229 if (group->pendingTokens.fetch_sub(1) == 1) { 230 group->cv.notify_all(); 231 for (auto &awaiter : group->awaiters) 232 awaiter(); 233 } 234 }; 235 236 if (token->ready) { 237 // Update group pending tokens immediately and maybe run awaiters. 238 onTokenReady(); 239 240 } else { 241 // Update group pending tokens when token will become ready. Because this 242 // will happen asynchronously we must ensure that `group` is alive until 243 // then, and re-ackquire the lock. 244 group->addRef(); 245 246 token->awaiters.push_back([group, onTokenReady]() { 247 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 
248 { 249 std::unique_lock<std::mutex> lockGroup(group->mu); 250 onTokenReady(); 251 } 252 group->dropRef(); 253 }); 254 } 255 256 return rank; 257 } 258 259 // Switches `async.token` to ready state and runs all awaiters. 260 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 261 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 262 { 263 std::unique_lock<std::mutex> lock(token->mu); 264 token->ready = true; 265 token->cv.notify_all(); 266 for (auto &awaiter : token->awaiters) 267 awaiter(); 268 } 269 270 // Async tokens created with a ref count `2` to keep token alive until the 271 // async task completes. Drop this reference explicitly when token emplaced. 272 token->dropRef(); 273 } 274 275 // Switches `async.value` to ready state and runs all awaiters. 276 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { 277 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 278 { 279 std::unique_lock<std::mutex> lock(value->mu); 280 value->ready = true; 281 value->cv.notify_all(); 282 for (auto &awaiter : value->awaiters) 283 awaiter(); 284 } 285 286 // Async values created with a ref count `2` to keep value alive until the 287 // async task completes. Drop this reference explicitly when value emplaced. 
288 value->dropRef(); 289 } 290 291 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 292 std::unique_lock<std::mutex> lock(token->mu); 293 if (!token->ready) 294 token->cv.wait(lock, [token] { return token->ready.load(); }); 295 } 296 297 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { 298 std::unique_lock<std::mutex> lock(value->mu); 299 if (!value->ready) 300 value->cv.wait(lock, [value] { return value->ready.load(); }); 301 } 302 303 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 304 std::unique_lock<std::mutex> lock(group->mu); 305 if (group->pendingTokens != 0) 306 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 307 } 308 309 // Returns a pointer to the storage owned by the async value. 310 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) { 311 return value->storage.data(); 312 } 313 314 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 315 auto *runtime = getDefaultAsyncRuntime(); 316 runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); }); 317 } 318 319 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 320 CoroHandle handle, 321 CoroResume resume) { 322 auto execute = [handle, resume]() { (*resume)(handle); }; 323 if (token->ready) { 324 execute(); 325 } else { 326 std::unique_lock<std::mutex> lock(token->mu); 327 token->awaiters.push_back([execute]() { execute(); }); 328 } 329 } 330 331 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, 332 CoroHandle handle, 333 CoroResume resume) { 334 auto execute = [handle, resume]() { (*resume)(handle); }; 335 if (value->ready) { 336 execute(); 337 } else { 338 std::unique_lock<std::mutex> lock(value->mu); 339 value->awaiters.push_back([execute]() { execute(); }); 340 } 341 } 342 343 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, 344 CoroHandle handle, 345 CoroResume resume) { 346 auto execute = 
[handle, resume]() { (*resume)(handle); }; 347 if (group->pendingTokens == 0) { 348 execute(); 349 } else { 350 std::unique_lock<std::mutex> lock(group->mu); 351 group->awaiters.push_back([execute]() { execute(); }); 352 } 353 } 354 355 //===----------------------------------------------------------------------===// 356 // Small async runtime support library for testing. 357 //===----------------------------------------------------------------------===// 358 359 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 360 static thread_local std::thread::id thisId = std::this_thread::get_id(); 361 std::cout << "Current thread id: " << thisId << std::endl; 362 } 363 364 //===----------------------------------------------------------------------===// 365 // MLIR Runner (JitRunner) dynamic library integration. 366 //===----------------------------------------------------------------------===// 367 368 // Export symbols for the MLIR runner integration. All other symbols are hidden. 369 #ifndef _WIN32 370 #define API __attribute__((visibility("default"))) 371 372 extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) { 373 auto exportSymbol = [&](llvm::StringRef name, auto ptr) { 374 assert(exportSymbols.count(name) == 0 && "symbol already exists"); 375 exportSymbols[name] = reinterpret_cast<void *>(ptr); 376 }; 377 378 exportSymbol("mlirAsyncRuntimeAddRef", 379 &mlir::runtime::mlirAsyncRuntimeAddRef); 380 exportSymbol("mlirAsyncRuntimeDropRef", 381 &mlir::runtime::mlirAsyncRuntimeDropRef); 382 exportSymbol("mlirAsyncRuntimeExecute", 383 &mlir::runtime::mlirAsyncRuntimeExecute); 384 exportSymbol("mlirAsyncRuntimeGetValueStorage", 385 &mlir::runtime::mlirAsyncRuntimeGetValueStorage); 386 exportSymbol("mlirAsyncRuntimeCreateToken", 387 &mlir::runtime::mlirAsyncRuntimeCreateToken); 388 exportSymbol("mlirAsyncRuntimeCreateValue", 389 &mlir::runtime::mlirAsyncRuntimeCreateValue); 390 exportSymbol("mlirAsyncRuntimeEmplaceToken", 391 
&mlir::runtime::mlirAsyncRuntimeEmplaceToken); 392 exportSymbol("mlirAsyncRuntimeEmplaceValue", 393 &mlir::runtime::mlirAsyncRuntimeEmplaceValue); 394 exportSymbol("mlirAsyncRuntimeAwaitToken", 395 &mlir::runtime::mlirAsyncRuntimeAwaitToken); 396 exportSymbol("mlirAsyncRuntimeAwaitValue", 397 &mlir::runtime::mlirAsyncRuntimeAwaitValue); 398 exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute", 399 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute); 400 exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute", 401 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute); 402 exportSymbol("mlirAsyncRuntimeCreateGroup", 403 &mlir::runtime::mlirAsyncRuntimeCreateGroup); 404 exportSymbol("mlirAsyncRuntimeAddTokenToGroup", 405 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup); 406 exportSymbol("mlirAsyncRuntimeAwaitAllInGroup", 407 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup); 408 exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute", 409 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute); 410 exportSymbol("mlirAsyncRuntimePrintCurrentThreadId", 411 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId); 412 } 413 414 extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); } 415 416 #endif // _WIN32 417 418 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 419