1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements basic Async runtime API for supporting Async dialect 10 // to LLVM dialect lowering. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/AsyncRuntime.h" 15 16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 17 18 #include <atomic> 19 #include <cassert> 20 #include <condition_variable> 21 #include <functional> 22 #include <iostream> 23 #include <mutex> 24 #include <thread> 25 #include <vector> 26 27 #include "llvm/ADT/StringMap.h" 28 29 using namespace mlir::runtime; 30 31 //===----------------------------------------------------------------------===// 32 // Async runtime API. 33 //===----------------------------------------------------------------------===// 34 35 namespace mlir { 36 namespace runtime { 37 namespace { 38 39 // Forward declare class defined below. 40 class RefCounted; 41 42 // -------------------------------------------------------------------------- // 43 // AsyncRuntime orchestrates all async operations and Async runtime API is built 44 // on top of the default runtime instance. 
// -------------------------------------------------------------------------- //

// Holds the bookkeeping shared by all objects created through one runtime
// instance. The only state is a counter of currently alive reference counted
// objects; it exists for debugging and is checked (in assert-enabled builds)
// when the runtime is destroyed.
class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  ~AsyncRuntime() {
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  // Returns the number of reference counted objects that are currently
  // registered with this runtime. The load is relaxed: the value is a
  // debugging aid and provides no ordering with respect to other memory.
  int32_t getNumRefCountedObjects() const {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

private:
  // Only `RefCounted` is allowed to update the live object counter; it does so
  // from its constructor and destructor.
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int32_t> numRefCountedObjects;
};

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
77 // -------------------------------------------------------------------------- // 78 79 class RefCounted { 80 public: 81 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) 82 : runtime(runtime), refCount(refCount) { 83 runtime->addNumRefCountedObjects(); 84 } 85 86 virtual ~RefCounted() { 87 assert(refCount.load() == 0 && "reference count must be zero"); 88 runtime->dropNumRefCountedObjects(); 89 } 90 91 RefCounted(const RefCounted &) = delete; 92 RefCounted &operator=(const RefCounted &) = delete; 93 94 void addRef(int32_t count = 1) { refCount.fetch_add(count); } 95 96 void dropRef(int32_t count = 1) { 97 int32_t previous = refCount.fetch_sub(count); 98 assert(previous >= count && "reference count should not go below zero"); 99 if (previous == count) 100 destroy(); 101 } 102 103 protected: 104 virtual void destroy() { delete this; } 105 106 private: 107 AsyncRuntime *runtime; 108 std::atomic<int32_t> refCount; 109 }; 110 111 } // namespace 112 113 // Returns the default per-process instance of an async runtime. 114 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() { 115 static auto runtime = std::make_unique<AsyncRuntime>(); 116 return runtime; 117 } 118 119 static void resetDefaultAsyncRuntime() { 120 return getDefaultAsyncRuntimeInstance().reset(); 121 } 122 123 static AsyncRuntime *getDefaultAsyncRuntime() { 124 return getDefaultAsyncRuntimeInstance().get(); 125 } 126 127 // Async token provides a mechanism to signal asynchronous operation completion. 128 struct AsyncToken : public RefCounted { 129 // AsyncToken created with a reference count of 2 because it will be returned 130 // to the `async.execute` caller and also will be later on emplaced by the 131 // asynchronously executed task. If the caller immediately will drop its 132 // reference we must ensure that the token will be alive until the 133 // asynchronous operation is completed. 
134 AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} 135 136 // Internal state below guarded by a mutex. 137 std::mutex mu; 138 std::condition_variable cv; 139 140 bool ready = false; 141 std::vector<std::function<void()>> awaiters; 142 }; 143 144 // Async value provides a mechanism to access the result of asynchronous 145 // operations. It owns the storage that is used to store/load the value of the 146 // underlying type, and a flag to signal if the value is ready or not. 147 struct AsyncValue : public RefCounted { 148 // AsyncValue similar to an AsyncToken created with a reference count of 2. 149 AsyncValue(AsyncRuntime *runtime, int32_t size) 150 : RefCounted(runtime, /*count=*/2), storage(size) {} 151 152 // Internal state below guarded by a mutex. 153 std::mutex mu; 154 std::condition_variable cv; 155 156 bool ready = false; 157 std::vector<std::function<void()>> awaiters; 158 159 // Use vector of bytes to store async value payload. 160 std::vector<int8_t> storage; 161 }; 162 163 // Async group provides a mechanism to group together multiple async tokens or 164 // values to await on all of them together (wait for the completion of all 165 // tokens or values added to the group). 166 struct AsyncGroup : public RefCounted { 167 AsyncGroup(AsyncRuntime *runtime) 168 : RefCounted(runtime), pendingTokens(0), rank(0) {} 169 170 std::atomic<int> pendingTokens; 171 std::atomic<int> rank; 172 173 // Internal state below guarded by a mutex. 174 std::mutex mu; 175 std::condition_variable cv; 176 177 std::vector<std::function<void()>> awaiters; 178 }; 179 180 } // namespace runtime 181 } // namespace mlir 182 183 // Adds references to reference counted runtime object. 184 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { 185 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 186 refCounted->addRef(count); 187 } 188 189 // Drops references from reference counted runtime object. 
190 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { 191 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 192 refCounted->dropRef(count); 193 } 194 195 // Creates a new `async.token` in not-ready state. 196 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 197 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime()); 198 return token; 199 } 200 201 // Creates a new `async.value` in not-ready state. 202 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) { 203 AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size); 204 return value; 205 } 206 207 // Create a new `async.group` in empty state. 208 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { 209 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime()); 210 return group; 211 } 212 213 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, 214 AsyncGroup *group) { 215 std::unique_lock<std::mutex> lockToken(token->mu); 216 std::unique_lock<std::mutex> lockGroup(group->mu); 217 218 // Get the rank of the token inside the group before we drop the reference. 219 int rank = group->rank.fetch_add(1); 220 group->pendingTokens.fetch_add(1); 221 222 auto onTokenReady = [group]() { 223 // Run all group awaiters if it was the last token in the group. 224 if (group->pendingTokens.fetch_sub(1) == 1) { 225 group->cv.notify_all(); 226 for (auto &awaiter : group->awaiters) 227 awaiter(); 228 } 229 }; 230 231 if (token->ready) { 232 // Update group pending tokens immediately and maybe run awaiters. 233 onTokenReady(); 234 235 } else { 236 // Update group pending tokens when token will become ready. Because this 237 // will happen asynchronously we must ensure that `group` is alive until 238 // then, and re-ackquire the lock. 239 group->addRef(); 240 241 token->awaiters.push_back([group, onTokenReady]() { 242 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 
243 { 244 std::unique_lock<std::mutex> lockGroup(group->mu); 245 onTokenReady(); 246 } 247 group->dropRef(); 248 }); 249 } 250 251 return rank; 252 } 253 254 // Switches `async.token` to ready state and runs all awaiters. 255 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 256 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 257 { 258 std::unique_lock<std::mutex> lock(token->mu); 259 token->ready = true; 260 token->cv.notify_all(); 261 for (auto &awaiter : token->awaiters) 262 awaiter(); 263 } 264 265 // Async tokens created with a ref count `2` to keep token alive until the 266 // async task completes. Drop this reference explicitly when token emplaced. 267 token->dropRef(); 268 } 269 270 // Switches `async.value` to ready state and runs all awaiters. 271 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { 272 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 273 { 274 std::unique_lock<std::mutex> lock(value->mu); 275 value->ready = true; 276 value->cv.notify_all(); 277 for (auto &awaiter : value->awaiters) 278 awaiter(); 279 } 280 281 // Async values created with a ref count `2` to keep value alive until the 282 // async task completes. Drop this reference explicitly when value emplaced. 
283 value->dropRef(); 284 } 285 286 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 287 std::unique_lock<std::mutex> lock(token->mu); 288 if (!token->ready) 289 token->cv.wait(lock, [token] { return token->ready; }); 290 } 291 292 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { 293 std::unique_lock<std::mutex> lock(value->mu); 294 if (!value->ready) 295 value->cv.wait(lock, [value] { return value->ready; }); 296 } 297 298 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 299 std::unique_lock<std::mutex> lock(group->mu); 300 if (group->pendingTokens != 0) 301 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 302 } 303 304 // Returns a pointer to the storage owned by the async value. 305 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) { 306 return value->storage.data(); 307 } 308 309 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 310 (*resume)(handle); 311 } 312 313 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 314 CoroHandle handle, 315 CoroResume resume) { 316 std::unique_lock<std::mutex> lock(token->mu); 317 auto execute = [handle, resume]() { (*resume)(handle); }; 318 if (token->ready) 319 execute(); 320 else 321 token->awaiters.push_back([execute]() { execute(); }); 322 } 323 324 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, 325 CoroHandle handle, 326 CoroResume resume) { 327 std::unique_lock<std::mutex> lock(value->mu); 328 auto execute = [handle, resume]() { (*resume)(handle); }; 329 if (value->ready) 330 execute(); 331 else 332 value->awaiters.push_back([execute]() { execute(); }); 333 } 334 335 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, 336 CoroHandle handle, 337 CoroResume resume) { 338 std::unique_lock<std::mutex> lock(group->mu); 339 auto execute = [handle, resume]() { (*resume)(handle); }; 340 if (group->pendingTokens == 0) 341 execute(); 342 
else 343 group->awaiters.push_back([execute]() { execute(); }); 344 } 345 346 //===----------------------------------------------------------------------===// 347 // Small async runtime support library for testing. 348 //===----------------------------------------------------------------------===// 349 350 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 351 static thread_local std::thread::id thisId = std::this_thread::get_id(); 352 std::cout << "Current thread id: " << thisId << std::endl; 353 } 354 355 //===----------------------------------------------------------------------===// 356 // MLIR Runner (JitRunner) dynamic library integration. 357 //===----------------------------------------------------------------------===// 358 359 // Export symbols for the MLIR runner integration. All other symbols are hidden. 360 #ifndef _WIN32 361 #define API __attribute__((visibility("default"))) 362 363 extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) { 364 auto exportSymbol = [&](llvm::StringRef name, auto ptr) { 365 assert(exportSymbols.count(name) == 0 && "symbol already exists"); 366 exportSymbols[name] = reinterpret_cast<void *>(ptr); 367 }; 368 369 exportSymbol("mlirAsyncRuntimeAddRef", 370 &mlir::runtime::mlirAsyncRuntimeAddRef); 371 exportSymbol("mlirAsyncRuntimeDropRef", 372 &mlir::runtime::mlirAsyncRuntimeDropRef); 373 exportSymbol("mlirAsyncRuntimeExecute", 374 &mlir::runtime::mlirAsyncRuntimeExecute); 375 exportSymbol("mlirAsyncRuntimeGetValueStorage", 376 &mlir::runtime::mlirAsyncRuntimeGetValueStorage); 377 exportSymbol("mlirAsyncRuntimeCreateToken", 378 &mlir::runtime::mlirAsyncRuntimeCreateToken); 379 exportSymbol("mlirAsyncRuntimeCreateValue", 380 &mlir::runtime::mlirAsyncRuntimeCreateValue); 381 exportSymbol("mlirAsyncRuntimeEmplaceToken", 382 &mlir::runtime::mlirAsyncRuntimeEmplaceToken); 383 exportSymbol("mlirAsyncRuntimeEmplaceValue", 384 &mlir::runtime::mlirAsyncRuntimeEmplaceValue); 385 
exportSymbol("mlirAsyncRuntimeAwaitToken", 386 &mlir::runtime::mlirAsyncRuntimeAwaitToken); 387 exportSymbol("mlirAsyncRuntimeAwaitValue", 388 &mlir::runtime::mlirAsyncRuntimeAwaitValue); 389 exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute", 390 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute); 391 exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute", 392 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute); 393 exportSymbol("mlirAsyncRuntimeCreateGroup", 394 &mlir::runtime::mlirAsyncRuntimeCreateGroup); 395 exportSymbol("mlirAsyncRuntimeAddTokenToGroup", 396 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup); 397 exportSymbol("mlirAsyncRuntimeAwaitAllInGroup", 398 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup); 399 exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute", 400 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute); 401 exportSymbol("mlirAsyncRuntimePrintCurrentThreadId", 402 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId); 403 } 404 405 extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); } 406 407 #endif // _WIN32 408 409 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 410