1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements basic Async runtime API for supporting Async dialect 10 // to LLVM dialect lowering. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/AsyncRuntime.h" 15 16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 17 18 #include <atomic> 19 #include <cassert> 20 #include <condition_variable> 21 #include <functional> 22 #include <iostream> 23 #include <mutex> 24 #include <thread> 25 #include <vector> 26 27 //===----------------------------------------------------------------------===// 28 // Async runtime API. 29 //===----------------------------------------------------------------------===// 30 31 namespace { 32 33 // Forward declare class defined below. 34 class RefCounted; 35 36 // -------------------------------------------------------------------------- // 37 // AsyncRuntime orchestrates all async operations and Async runtime API is built 38 // on top of the default runtime instance. 39 // -------------------------------------------------------------------------- // 40 41 class AsyncRuntime { 42 public: 43 AsyncRuntime() : numRefCountedObjects(0) {} 44 45 ~AsyncRuntime() { 46 assert(getNumRefCountedObjects() == 0 && 47 "all ref counted objects must be destroyed"); 48 } 49 50 int32_t getNumRefCountedObjects() { 51 return numRefCountedObjects.load(std::memory_order_relaxed); 52 } 53 54 private: 55 friend class RefCounted; 56 57 // Count the total number of reference counted objects in this instance 58 // of an AsyncRuntime. For debugging purposes only. 59 void addNumRefCountedObjects() { 60 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); 61 } 62 void dropNumRefCountedObjects() { 63 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); 64 } 65 66 std::atomic<int32_t> numRefCountedObjects; 67 }; 68 69 // Returns the default per-process instance of an async runtime. 70 AsyncRuntime *getDefaultAsyncRuntimeInstance() { 71 static auto runtime = std::make_unique<AsyncRuntime>(); 72 return runtime.get(); 73 } 74 75 // -------------------------------------------------------------------------- // 76 // A base class for all reference counted objects created by the async runtime. 77 // -------------------------------------------------------------------------- // 78 79 class RefCounted { 80 public: 81 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) 82 : runtime(runtime), refCount(refCount) { 83 runtime->addNumRefCountedObjects(); 84 } 85 86 virtual ~RefCounted() { 87 assert(refCount.load() == 0 && "reference count must be zero"); 88 runtime->dropNumRefCountedObjects(); 89 } 90 91 RefCounted(const RefCounted &) = delete; 92 RefCounted &operator=(const RefCounted &) = delete; 93 94 void addRef(int32_t count = 1) { refCount.fetch_add(count); } 95 96 void dropRef(int32_t count = 1) { 97 int32_t previous = refCount.fetch_sub(count); 98 assert(previous >= count && "reference count should not go below zero"); 99 if (previous == count) 100 destroy(); 101 } 102 103 protected: 104 virtual void destroy() { delete this; } 105 106 private: 107 AsyncRuntime *runtime; 108 std::atomic<int32_t> refCount; 109 }; 110 111 } // namespace 112 113 struct AsyncToken : public RefCounted { 114 // AsyncToken created with a reference count of 2 because it will be returned 115 // to the `async.execute` caller and also will be later on emplaced by the 116 // asynchronously executed task. If the caller immediately will drop its 117 // reference we must ensure that the token will be alive until the 118 // asynchronous operation is completed. 119 AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} 120 121 // Internal state below guarded by a mutex. 122 std::mutex mu; 123 std::condition_variable cv; 124 125 bool ready = false; 126 std::vector<std::function<void()>> awaiters; 127 }; 128 129 struct AsyncGroup : public RefCounted { 130 AsyncGroup(AsyncRuntime *runtime) 131 : RefCounted(runtime), pendingTokens(0), rank(0) {} 132 133 std::atomic<int> pendingTokens; 134 std::atomic<int> rank; 135 136 // Internal state below guarded by a mutex. 137 std::mutex mu; 138 std::condition_variable cv; 139 140 std::vector<std::function<void()>> awaiters; 141 }; 142 143 // Adds references to reference counted runtime object. 144 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 145 mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { 146 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 147 refCounted->addRef(count); 148 } 149 150 // Drops references from reference counted runtime object. 151 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 152 mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { 153 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 154 refCounted->dropRef(count); 155 } 156 157 // Create a new `async.token` in not-ready state. 158 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 159 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); 160 return token; 161 } 162 163 // Create a new `async.group` in empty state. 164 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() { 165 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); 166 return group; 167 } 168 169 extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t 170 mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) { 171 std::unique_lock<std::mutex> lockToken(token->mu); 172 std::unique_lock<std::mutex> lockGroup(group->mu); 173 174 // Get the rank of the token inside the group before we drop the reference. 175 int rank = group->rank.fetch_add(1); 176 group->pendingTokens.fetch_add(1); 177 178 auto onTokenReady = [group, token](bool dropRef) { 179 // Run all group awaiters if it was the last token in the group. 180 if (group->pendingTokens.fetch_sub(1) == 1) { 181 group->cv.notify_all(); 182 for (auto &awaiter : group->awaiters) 183 awaiter(); 184 } 185 186 // We no longer need the token or the group, drop references on them. 187 if (dropRef) { 188 group->dropRef(); 189 token->dropRef(); 190 } 191 }; 192 193 if (token->ready) { 194 onTokenReady(false); 195 } else { 196 group->addRef(); 197 token->addRef(); 198 token->awaiters.push_back([onTokenReady]() { onTokenReady(true); }); 199 } 200 201 return rank; 202 } 203 204 // Switches `async.token` to ready state and runs all awaiters. 205 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 206 std::unique_lock<std::mutex> lock(token->mu); 207 token->ready = true; 208 token->cv.notify_all(); 209 for (auto &awaiter : token->awaiters) 210 awaiter(); 211 212 // Async tokens created with a ref count `2` to keep token alive until the 213 // async task completes. Drop this reference explicitly when token emplaced. 214 token->dropRef(); 215 } 216 217 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 218 std::unique_lock<std::mutex> lock(token->mu); 219 if (!token->ready) 220 token->cv.wait(lock, [token] { return token->ready; }); 221 } 222 223 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 224 mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 225 std::unique_lock<std::mutex> lock(group->mu); 226 if (group->pendingTokens != 0) 227 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 228 } 229 230 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 231 #if LLVM_ENABLE_THREADS 232 std::thread thread([handle, resume]() { (*resume)(handle); }); 233 thread.detach(); 234 #else 235 (*resume)(handle); 236 #endif 237 } 238 239 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 240 CoroHandle handle, 241 CoroResume resume) { 242 std::unique_lock<std::mutex> lock(token->mu); 243 244 auto execute = [handle, resume, token](bool dropRef) { 245 if (dropRef) 246 token->dropRef(); 247 mlirAsyncRuntimeExecute(handle, resume); 248 }; 249 250 if (token->ready) { 251 execute(false); 252 } else { 253 token->addRef(); 254 token->awaiters.push_back([execute]() { execute(true); }); 255 } 256 } 257 258 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 259 mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle, 260 CoroResume resume) { 261 std::unique_lock<std::mutex> lock(group->mu); 262 263 auto execute = [handle, resume, group](bool dropRef) { 264 if (dropRef) 265 group->dropRef(); 266 mlirAsyncRuntimeExecute(handle, resume); 267 }; 268 269 if (group->pendingTokens == 0) { 270 execute(false); 271 } else { 272 group->addRef(); 273 group->awaiters.push_back([execute]() { execute(true); }); 274 } 275 } 276 277 //===----------------------------------------------------------------------===// 278 // Small async runtime support library for testing. 279 //===----------------------------------------------------------------------===// 280 281 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 282 static thread_local std::thread::id thisId = std::this_thread::get_id(); 283 std::cout << "Current thread id: " << thisId << "\n"; 284 } 285 286 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 287