//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic Async runtime API for supporting Async dialect
// to LLVM dialect lowering.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward declare class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations and Async runtime API is built
// on top of the default runtime instance.
// -------------------------------------------------------------------------- //

// Process-wide async runtime state. Today the only state is a debugging
// counter of live reference counted objects; destroying the runtime while any
// of them is still alive trips an assertion.
class AsyncRuntime {
public:
  AsyncRuntime() = default;

  ~AsyncRuntime() {
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  // Number of reference counted objects currently registered with this
  // runtime instance.
  int32_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

private:
  // RefCounted registers/unregisters itself through the hooks below.
  friend class RefCounted;

  // The counter is for debugging only, so relaxed ordering is sufficient:
  // it does not publish or synchronize any other memory.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }

  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int32_t> numRefCountedObjects{0};
};

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
75 // -------------------------------------------------------------------------- // 76 77 class RefCounted { 78 public: 79 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) 80 : runtime(runtime), refCount(refCount) { 81 runtime->addNumRefCountedObjects(); 82 } 83 84 virtual ~RefCounted() { 85 assert(refCount.load() == 0 && "reference count must be zero"); 86 runtime->dropNumRefCountedObjects(); 87 } 88 89 RefCounted(const RefCounted &) = delete; 90 RefCounted &operator=(const RefCounted &) = delete; 91 92 void addRef(int32_t count = 1) { refCount.fetch_add(count); } 93 94 void dropRef(int32_t count = 1) { 95 int32_t previous = refCount.fetch_sub(count); 96 assert(previous >= count && "reference count should not go below zero"); 97 if (previous == count) 98 destroy(); 99 } 100 101 protected: 102 virtual void destroy() { delete this; } 103 104 private: 105 AsyncRuntime *runtime; 106 std::atomic<int32_t> refCount; 107 }; 108 109 } // namespace 110 111 // Returns the default per-process instance of an async runtime. 112 static AsyncRuntime *getDefaultAsyncRuntimeInstance() { 113 static auto runtime = std::make_unique<AsyncRuntime>(); 114 return runtime.get(); 115 } 116 117 // Async token provides a mechanism to signal asynchronous operation completion. 118 struct AsyncToken : public RefCounted { 119 // AsyncToken created with a reference count of 2 because it will be returned 120 // to the `async.execute` caller and also will be later on emplaced by the 121 // asynchronously executed task. If the caller immediately will drop its 122 // reference we must ensure that the token will be alive until the 123 // asynchronous operation is completed. 124 AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} 125 126 // Internal state below guarded by a mutex. 
127 std::mutex mu; 128 std::condition_variable cv; 129 130 bool ready = false; 131 std::vector<std::function<void()>> awaiters; 132 }; 133 134 // Async value provides a mechanism to access the result of asynchronous 135 // operations. It owns the storage that is used to store/load the value of the 136 // underlying type, and a flag to signal if the value is ready or not. 137 struct AsyncValue : public RefCounted { 138 // AsyncValue similar to an AsyncToken created with a reference count of 2. 139 AsyncValue(AsyncRuntime *runtime, int32_t size) 140 : RefCounted(runtime, /*count=*/2), storage(size) {} 141 142 // Internal state below guarded by a mutex. 143 std::mutex mu; 144 std::condition_variable cv; 145 146 bool ready = false; 147 std::vector<std::function<void()>> awaiters; 148 149 // Use vector of bytes to store async value payload. 150 std::vector<int8_t> storage; 151 }; 152 153 // Async group provides a mechanism to group together multiple async tokens or 154 // values to await on all of them together (wait for the completion of all 155 // tokens or values added to the group). 156 struct AsyncGroup : public RefCounted { 157 AsyncGroup(AsyncRuntime *runtime) 158 : RefCounted(runtime), pendingTokens(0), rank(0) {} 159 160 std::atomic<int> pendingTokens; 161 std::atomic<int> rank; 162 163 // Internal state below guarded by a mutex. 164 std::mutex mu; 165 std::condition_variable cv; 166 167 std::vector<std::function<void()>> awaiters; 168 }; 169 170 } // namespace runtime 171 } // namespace mlir 172 173 // Adds references to reference counted runtime object. 174 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { 175 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 176 refCounted->addRef(count); 177 } 178 179 // Drops references from reference counted runtime object. 
180 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { 181 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 182 refCounted->dropRef(count); 183 } 184 185 // Creates a new `async.token` in not-ready state. 186 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 187 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); 188 return token; 189 } 190 191 // Creates a new `async.value` in not-ready state. 192 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) { 193 AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size); 194 return value; 195 } 196 197 // Create a new `async.group` in empty state. 198 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { 199 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); 200 return group; 201 } 202 203 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, 204 AsyncGroup *group) { 205 std::unique_lock<std::mutex> lockToken(token->mu); 206 std::unique_lock<std::mutex> lockGroup(group->mu); 207 208 // Get the rank of the token inside the group before we drop the reference. 209 int rank = group->rank.fetch_add(1); 210 group->pendingTokens.fetch_add(1); 211 212 auto onTokenReady = [group]() { 213 // Run all group awaiters if it was the last token in the group. 214 if (group->pendingTokens.fetch_sub(1) == 1) { 215 group->cv.notify_all(); 216 for (auto &awaiter : group->awaiters) 217 awaiter(); 218 } 219 }; 220 221 if (token->ready) { 222 // Update group pending tokens immediately and maybe run awaiters. 223 onTokenReady(); 224 225 } else { 226 // Update group pending tokens when token will become ready. Because this 227 // will happen asynchronously we must ensure that `group` is alive until 228 // then, and re-ackquire the lock. 229 group->addRef(); 230 231 token->awaiters.push_back([group, onTokenReady]() { 232 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 
233 { 234 std::unique_lock<std::mutex> lockGroup(group->mu); 235 onTokenReady(); 236 } 237 group->dropRef(); 238 }); 239 } 240 241 return rank; 242 } 243 244 // Switches `async.token` to ready state and runs all awaiters. 245 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 246 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 247 { 248 std::unique_lock<std::mutex> lock(token->mu); 249 token->ready = true; 250 token->cv.notify_all(); 251 for (auto &awaiter : token->awaiters) 252 awaiter(); 253 } 254 255 // Async tokens created with a ref count `2` to keep token alive until the 256 // async task completes. Drop this reference explicitly when token emplaced. 257 token->dropRef(); 258 } 259 260 // Switches `async.value` to ready state and runs all awaiters. 261 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { 262 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 263 { 264 std::unique_lock<std::mutex> lock(value->mu); 265 value->ready = true; 266 value->cv.notify_all(); 267 for (auto &awaiter : value->awaiters) 268 awaiter(); 269 } 270 271 // Async values created with a ref count `2` to keep value alive until the 272 // async task completes. Drop this reference explicitly when value emplaced. 
273 value->dropRef(); 274 } 275 276 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 277 std::unique_lock<std::mutex> lock(token->mu); 278 if (!token->ready) 279 token->cv.wait(lock, [token] { return token->ready; }); 280 } 281 282 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { 283 std::unique_lock<std::mutex> lock(value->mu); 284 if (!value->ready) 285 value->cv.wait(lock, [value] { return value->ready; }); 286 } 287 288 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 289 std::unique_lock<std::mutex> lock(group->mu); 290 if (group->pendingTokens != 0) 291 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 292 } 293 294 // Returns a pointer to the storage owned by the async value. 295 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) { 296 return value->storage.data(); 297 } 298 299 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 300 (*resume)(handle); 301 } 302 303 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 304 CoroHandle handle, 305 CoroResume resume) { 306 std::unique_lock<std::mutex> lock(token->mu); 307 auto execute = [handle, resume]() { (*resume)(handle); }; 308 if (token->ready) 309 execute(); 310 else 311 token->awaiters.push_back([execute]() { execute(); }); 312 } 313 314 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, 315 CoroHandle handle, 316 CoroResume resume) { 317 std::unique_lock<std::mutex> lock(value->mu); 318 auto execute = [handle, resume]() { (*resume)(handle); }; 319 if (value->ready) 320 execute(); 321 else 322 value->awaiters.push_back([execute]() { execute(); }); 323 } 324 325 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, 326 CoroHandle handle, 327 CoroResume resume) { 328 std::unique_lock<std::mutex> lock(group->mu); 329 auto execute = [handle, resume]() { (*resume)(handle); }; 330 if (group->pendingTokens == 0) 331 execute(); 332 
else 333 group->awaiters.push_back([execute]() { execute(); }); 334 } 335 336 //===----------------------------------------------------------------------===// 337 // Small async runtime support library for testing. 338 //===----------------------------------------------------------------------===// 339 340 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 341 static thread_local std::thread::id thisId = std::this_thread::get_id(); 342 std::cout << "Current thread id: " << thisId << std::endl; 343 } 344 345 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 346