1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements basic Async runtime API for supporting Async dialect 10 // to LLVM dialect lowering. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/AsyncRuntime.h" 15 16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 17 18 #include <atomic> 19 #include <cassert> 20 #include <condition_variable> 21 #include <functional> 22 #include <iostream> 23 #include <mutex> 24 #include <thread> 25 #include <vector> 26 27 #include "llvm/ADT/StringMap.h" 28 #include "llvm/Support/ThreadPool.h" 29 30 using namespace mlir::runtime; 31 32 //===----------------------------------------------------------------------===// 33 // Async runtime API. 34 //===----------------------------------------------------------------------===// 35 36 namespace mlir { 37 namespace runtime { 38 namespace { 39 40 // Forward declare class defined below. 41 class RefCounted; 42 43 // -------------------------------------------------------------------------- // 44 // AsyncRuntime orchestrates all async operations and Async runtime API is built 45 // on top of the default runtime instance. 46 // -------------------------------------------------------------------------- // 47 48 class AsyncRuntime { 49 public: 50 AsyncRuntime() : numRefCountedObjects(0) {} 51 52 ~AsyncRuntime() { 53 threadPool.wait(); // wait for the completion of all async tasks 54 assert(getNumRefCountedObjects() == 0 && 55 "all ref counted objects must be destroyed"); 56 } 57 58 int32_t getNumRefCountedObjects() { 59 return numRefCountedObjects.load(std::memory_order_relaxed); 60 } 61 62 llvm::ThreadPool &getThreadPool() { return threadPool; } 63 64 private: 65 friend class RefCounted; 66 67 // Count the total number of reference counted objects in this instance 68 // of an AsyncRuntime. For debugging purposes only. 69 void addNumRefCountedObjects() { 70 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); 71 } 72 void dropNumRefCountedObjects() { 73 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); 74 } 75 76 std::atomic<int32_t> numRefCountedObjects; 77 llvm::ThreadPool threadPool; 78 }; 79 80 // -------------------------------------------------------------------------- // 81 // A base class for all reference counted objects created by the async runtime. 82 // -------------------------------------------------------------------------- // 83 84 class RefCounted { 85 public: 86 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) 87 : runtime(runtime), refCount(refCount) { 88 runtime->addNumRefCountedObjects(); 89 } 90 91 virtual ~RefCounted() { 92 assert(refCount.load() == 0 && "reference count must be zero"); 93 runtime->dropNumRefCountedObjects(); 94 } 95 96 RefCounted(const RefCounted &) = delete; 97 RefCounted &operator=(const RefCounted &) = delete; 98 99 void addRef(int32_t count = 1) { refCount.fetch_add(count); } 100 101 void dropRef(int32_t count = 1) { 102 int32_t previous = refCount.fetch_sub(count); 103 assert(previous >= count && "reference count should not go below zero"); 104 if (previous == count) 105 destroy(); 106 } 107 108 protected: 109 virtual void destroy() { delete this; } 110 111 private: 112 AsyncRuntime *runtime; 113 std::atomic<int32_t> refCount; 114 }; 115 116 } // namespace 117 118 // Returns the default per-process instance of an async runtime. 119 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() { 120 static auto runtime = std::make_unique<AsyncRuntime>(); 121 return runtime; 122 } 123 124 static void resetDefaultAsyncRuntime() { 125 return getDefaultAsyncRuntimeInstance().reset(); 126 } 127 128 static AsyncRuntime *getDefaultAsyncRuntime() { 129 return getDefaultAsyncRuntimeInstance().get(); 130 } 131 132 // Async token provides a mechanism to signal asynchronous operation completion. 133 struct AsyncToken : public RefCounted { 134 // AsyncToken created with a reference count of 2 because it will be returned 135 // to the `async.execute` caller and also will be later on emplaced by the 136 // asynchronously executed task. If the caller immediately will drop its 137 // reference we must ensure that the token will be alive until the 138 // asynchronous operation is completed. 139 AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} 140 141 // Internal state below guarded by a mutex. 142 std::mutex mu; 143 std::condition_variable cv; 144 145 bool ready = false; 146 std::vector<std::function<void()>> awaiters; 147 }; 148 149 // Async value provides a mechanism to access the result of asynchronous 150 // operations. It owns the storage that is used to store/load the value of the 151 // underlying type, and a flag to signal if the value is ready or not. 152 struct AsyncValue : public RefCounted { 153 // AsyncValue similar to an AsyncToken created with a reference count of 2. 154 AsyncValue(AsyncRuntime *runtime, int32_t size) 155 : RefCounted(runtime, /*count=*/2), storage(size) {} 156 157 // Internal state below guarded by a mutex. 158 std::mutex mu; 159 std::condition_variable cv; 160 161 bool ready = false; 162 std::vector<std::function<void()>> awaiters; 163 164 // Use vector of bytes to store async value payload. 165 std::vector<int8_t> storage; 166 }; 167 168 // Async group provides a mechanism to group together multiple async tokens or 169 // values to await on all of them together (wait for the completion of all 170 // tokens or values added to the group). 171 struct AsyncGroup : public RefCounted { 172 AsyncGroup(AsyncRuntime *runtime) 173 : RefCounted(runtime), pendingTokens(0), rank(0) {} 174 175 std::atomic<int> pendingTokens; 176 std::atomic<int> rank; 177 178 // Internal state below guarded by a mutex. 179 std::mutex mu; 180 std::condition_variable cv; 181 182 std::vector<std::function<void()>> awaiters; 183 }; 184 185 } // namespace runtime 186 } // namespace mlir 187 188 // Adds references to reference counted runtime object. 189 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { 190 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 191 refCounted->addRef(count); 192 } 193 194 // Drops references from reference counted runtime object. 195 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { 196 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 197 refCounted->dropRef(count); 198 } 199 200 // Creates a new `async.token` in not-ready state. 201 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 202 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime()); 203 return token; 204 } 205 206 // Creates a new `async.value` in not-ready state. 207 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) { 208 AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size); 209 return value; 210 } 211 212 // Create a new `async.group` in empty state. 213 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { 214 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime()); 215 return group; 216 } 217 218 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, 219 AsyncGroup *group) { 220 std::unique_lock<std::mutex> lockToken(token->mu); 221 std::unique_lock<std::mutex> lockGroup(group->mu); 222 223 // Get the rank of the token inside the group before we drop the reference. 224 int rank = group->rank.fetch_add(1); 225 group->pendingTokens.fetch_add(1); 226 227 auto onTokenReady = [group]() { 228 // Run all group awaiters if it was the last token in the group. 229 if (group->pendingTokens.fetch_sub(1) == 1) { 230 group->cv.notify_all(); 231 for (auto &awaiter : group->awaiters) 232 awaiter(); 233 } 234 }; 235 236 if (token->ready) { 237 // Update group pending tokens immediately and maybe run awaiters. 238 onTokenReady(); 239 240 } else { 241 // Update group pending tokens when token will become ready. Because this 242 // will happen asynchronously we must ensure that `group` is alive until 243 // then, and re-ackquire the lock. 244 group->addRef(); 245 246 token->awaiters.push_back([group, onTokenReady]() { 247 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 248 { 249 std::unique_lock<std::mutex> lockGroup(group->mu); 250 onTokenReady(); 251 } 252 group->dropRef(); 253 }); 254 } 255 256 return rank; 257 } 258 259 // Switches `async.token` to ready state and runs all awaiters. 260 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 261 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 262 { 263 std::unique_lock<std::mutex> lock(token->mu); 264 token->ready = true; 265 token->cv.notify_all(); 266 for (auto &awaiter : token->awaiters) 267 awaiter(); 268 } 269 270 // Async tokens created with a ref count `2` to keep token alive until the 271 // async task completes. Drop this reference explicitly when token emplaced. 272 token->dropRef(); 273 } 274 275 // Switches `async.value` to ready state and runs all awaiters. 276 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { 277 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 278 { 279 std::unique_lock<std::mutex> lock(value->mu); 280 value->ready = true; 281 value->cv.notify_all(); 282 for (auto &awaiter : value->awaiters) 283 awaiter(); 284 } 285 286 // Async values created with a ref count `2` to keep value alive until the 287 // async task completes. Drop this reference explicitly when value emplaced. 288 value->dropRef(); 289 } 290 291 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 292 std::unique_lock<std::mutex> lock(token->mu); 293 if (!token->ready) 294 token->cv.wait(lock, [token] { return token->ready; }); 295 } 296 297 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { 298 std::unique_lock<std::mutex> lock(value->mu); 299 if (!value->ready) 300 value->cv.wait(lock, [value] { return value->ready; }); 301 } 302 303 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 304 std::unique_lock<std::mutex> lock(group->mu); 305 if (group->pendingTokens != 0) 306 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 307 } 308 309 // Returns a pointer to the storage owned by the async value. 310 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) { 311 return value->storage.data(); 312 } 313 314 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 315 auto *runtime = getDefaultAsyncRuntime(); 316 runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); }); 317 } 318 319 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 320 CoroHandle handle, 321 CoroResume resume) { 322 std::unique_lock<std::mutex> lock(token->mu); 323 auto execute = [handle, resume]() { (*resume)(handle); }; 324 if (token->ready) 325 execute(); 326 else 327 token->awaiters.push_back([execute]() { execute(); }); 328 } 329 330 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, 331 CoroHandle handle, 332 CoroResume resume) { 333 std::unique_lock<std::mutex> lock(value->mu); 334 auto execute = [handle, resume]() { (*resume)(handle); }; 335 if (value->ready) 336 execute(); 337 else 338 value->awaiters.push_back([execute]() { execute(); }); 339 } 340 341 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, 342 CoroHandle handle, 343 CoroResume resume) { 344 std::unique_lock<std::mutex> lock(group->mu); 345 auto execute = [handle, resume]() { (*resume)(handle); }; 346 if (group->pendingTokens == 0) 347 execute(); 348 else 349 group->awaiters.push_back([execute]() { execute(); }); 350 } 351 352 //===----------------------------------------------------------------------===// 353 // Small async runtime support library for testing. 354 //===----------------------------------------------------------------------===// 355 356 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 357 static thread_local std::thread::id thisId = std::this_thread::get_id(); 358 std::cout << "Current thread id: " << thisId << std::endl; 359 } 360 361 //===----------------------------------------------------------------------===// 362 // MLIR Runner (JitRunner) dynamic library integration. 363 //===----------------------------------------------------------------------===// 364 365 // Export symbols for the MLIR runner integration. All other symbols are hidden. 366 #ifndef _WIN32 367 #define API __attribute__((visibility("default"))) 368 369 extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) { 370 auto exportSymbol = [&](llvm::StringRef name, auto ptr) { 371 assert(exportSymbols.count(name) == 0 && "symbol already exists"); 372 exportSymbols[name] = reinterpret_cast<void *>(ptr); 373 }; 374 375 exportSymbol("mlirAsyncRuntimeAddRef", 376 &mlir::runtime::mlirAsyncRuntimeAddRef); 377 exportSymbol("mlirAsyncRuntimeDropRef", 378 &mlir::runtime::mlirAsyncRuntimeDropRef); 379 exportSymbol("mlirAsyncRuntimeExecute", 380 &mlir::runtime::mlirAsyncRuntimeExecute); 381 exportSymbol("mlirAsyncRuntimeGetValueStorage", 382 &mlir::runtime::mlirAsyncRuntimeGetValueStorage); 383 exportSymbol("mlirAsyncRuntimeCreateToken", 384 &mlir::runtime::mlirAsyncRuntimeCreateToken); 385 exportSymbol("mlirAsyncRuntimeCreateValue", 386 &mlir::runtime::mlirAsyncRuntimeCreateValue); 387 exportSymbol("mlirAsyncRuntimeEmplaceToken", 388 &mlir::runtime::mlirAsyncRuntimeEmplaceToken); 389 exportSymbol("mlirAsyncRuntimeEmplaceValue", 390 &mlir::runtime::mlirAsyncRuntimeEmplaceValue); 391 exportSymbol("mlirAsyncRuntimeAwaitToken", 392 &mlir::runtime::mlirAsyncRuntimeAwaitToken); 393 exportSymbol("mlirAsyncRuntimeAwaitValue", 394 &mlir::runtime::mlirAsyncRuntimeAwaitValue); 395 exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute", 396 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute); 397 exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute", 398 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute); 399 exportSymbol("mlirAsyncRuntimeCreateGroup", 400 &mlir::runtime::mlirAsyncRuntimeCreateGroup); 401 exportSymbol("mlirAsyncRuntimeAddTokenToGroup", 402 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup); 403 exportSymbol("mlirAsyncRuntimeAwaitAllInGroup", 404 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup); 405 exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute", 406 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute); 407 exportSymbol("mlirAsyncRuntimePrintCurrentThreadId", 408 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId); 409 } 410 411 extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); } 412 413 #endif // _WIN32 414 415 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 416