1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements basic Async runtime API for supporting Async dialect 10 // to LLVM dialect lowering. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/AsyncRuntime.h" 15 16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 17 18 #include <atomic> 19 #include <cassert> 20 #include <condition_variable> 21 #include <functional> 22 #include <iostream> 23 #include <mutex> 24 #include <thread> 25 #include <vector> 26 27 #include "llvm/Support/ThreadPool.h" 28 29 //===----------------------------------------------------------------------===// 30 // Async runtime API. 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 35 // Forward declare class defined below. 36 class RefCounted; 37 38 // -------------------------------------------------------------------------- // 39 // AsyncRuntime orchestrates all async operations and Async runtime API is built 40 // on top of the default runtime instance. 41 // -------------------------------------------------------------------------- // 42 43 class AsyncRuntime { 44 public: 45 AsyncRuntime() : numRefCountedObjects(0) {} 46 47 ~AsyncRuntime() { 48 assert(getNumRefCountedObjects() == 0 && 49 "all ref counted objects must be destroyed"); 50 } 51 52 int32_t getNumRefCountedObjects() { 53 return numRefCountedObjects.load(std::memory_order_relaxed); 54 } 55 56 private: 57 friend class RefCounted; 58 59 // Count the total number of reference counted objects in this instance 60 // of an AsyncRuntime. For debugging purposes only. 61 void addNumRefCountedObjects() { 62 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); 63 } 64 void dropNumRefCountedObjects() { 65 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); 66 } 67 68 std::atomic<int32_t> numRefCountedObjects; 69 }; 70 71 // Returns the default per-process instance of an async runtime. 72 AsyncRuntime *getDefaultAsyncRuntimeInstance() { 73 static auto runtime = std::make_unique<AsyncRuntime>(); 74 return runtime.get(); 75 } 76 77 // -------------------------------------------------------------------------- // 78 // A base class for all reference counted objects created by the async runtime. 79 // -------------------------------------------------------------------------- // 80 81 class RefCounted { 82 public: 83 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) 84 : runtime(runtime), refCount(refCount) { 85 runtime->addNumRefCountedObjects(); 86 } 87 88 virtual ~RefCounted() { 89 assert(refCount.load() == 0 && "reference count must be zero"); 90 runtime->dropNumRefCountedObjects(); 91 } 92 93 RefCounted(const RefCounted &) = delete; 94 RefCounted &operator=(const RefCounted &) = delete; 95 96 void addRef(int32_t count = 1) { refCount.fetch_add(count); } 97 98 void dropRef(int32_t count = 1) { 99 int32_t previous = refCount.fetch_sub(count); 100 assert(previous >= count && "reference count should not go below zero"); 101 if (previous == count) 102 destroy(); 103 } 104 105 protected: 106 virtual void destroy() { delete this; } 107 108 private: 109 AsyncRuntime *runtime; 110 std::atomic<int32_t> refCount; 111 }; 112 113 } // namespace 114 115 struct AsyncToken : public RefCounted { 116 // AsyncToken created with a reference count of 2 because it will be returned 117 // to the `async.execute` caller and also will be later on emplaced by the 118 // asynchronously executed task. If the caller immediately will drop its 119 // reference we must ensure that the token will be alive until the 120 // asynchronous operation is completed. 121 AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} 122 123 // Internal state below guarded by a mutex. 124 std::mutex mu; 125 std::condition_variable cv; 126 127 bool ready = false; 128 std::vector<std::function<void()>> awaiters; 129 }; 130 131 struct AsyncGroup : public RefCounted { 132 AsyncGroup(AsyncRuntime *runtime) 133 : RefCounted(runtime), pendingTokens(0), rank(0) {} 134 135 std::atomic<int> pendingTokens; 136 std::atomic<int> rank; 137 138 // Internal state below guarded by a mutex. 139 std::mutex mu; 140 std::condition_variable cv; 141 142 std::vector<std::function<void()>> awaiters; 143 }; 144 145 // Adds references to reference counted runtime object. 146 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 147 mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { 148 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 149 refCounted->addRef(count); 150 } 151 152 // Drops references from reference counted runtime object. 153 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 154 mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { 155 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 156 refCounted->dropRef(count); 157 } 158 159 // Create a new `async.token` in not-ready state. 160 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 161 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); 162 return token; 163 } 164 165 // Create a new `async.group` in empty state. 166 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() { 167 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); 168 return group; 169 } 170 171 extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t 172 mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) { 173 std::unique_lock<std::mutex> lockToken(token->mu); 174 std::unique_lock<std::mutex> lockGroup(group->mu); 175 176 // Get the rank of the token inside the group before we drop the reference. 177 int rank = group->rank.fetch_add(1); 178 group->pendingTokens.fetch_add(1); 179 180 auto onTokenReady = [group, token](bool dropRef) { 181 // Run all group awaiters if it was the last token in the group. 182 if (group->pendingTokens.fetch_sub(1) == 1) { 183 group->cv.notify_all(); 184 for (auto &awaiter : group->awaiters) 185 awaiter(); 186 } 187 188 // We no longer need the token or the group, drop references on them. 189 if (dropRef) { 190 group->dropRef(); 191 token->dropRef(); 192 } 193 }; 194 195 if (token->ready) { 196 onTokenReady(false); 197 } else { 198 group->addRef(); 199 token->addRef(); 200 token->awaiters.push_back([onTokenReady]() { onTokenReady(true); }); 201 } 202 203 return rank; 204 } 205 206 // Switches `async.token` to ready state and runs all awaiters. 207 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 208 std::unique_lock<std::mutex> lock(token->mu); 209 token->ready = true; 210 token->cv.notify_all(); 211 for (auto &awaiter : token->awaiters) 212 awaiter(); 213 214 // Async tokens created with a ref count `2` to keep token alive until the 215 // async task completes. Drop this reference explicitly when token emplaced. 216 token->dropRef(); 217 } 218 219 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 220 std::unique_lock<std::mutex> lock(token->mu); 221 if (!token->ready) 222 token->cv.wait(lock, [token] { return token->ready; }); 223 } 224 225 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 226 mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 227 std::unique_lock<std::mutex> lock(group->mu); 228 if (group->pendingTokens != 0) 229 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 230 } 231 232 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 233 #if LLVM_ENABLE_THREADS 234 static llvm::ThreadPool *threadPool = new llvm::ThreadPool(); 235 threadPool->async([handle, resume]() { (*resume)(handle); }); 236 #else 237 (*resume)(handle); 238 #endif 239 } 240 241 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 242 CoroHandle handle, 243 CoroResume resume) { 244 std::unique_lock<std::mutex> lock(token->mu); 245 246 auto execute = [handle, resume, token](bool dropRef) { 247 if (dropRef) 248 token->dropRef(); 249 mlirAsyncRuntimeExecute(handle, resume); 250 }; 251 252 if (token->ready) { 253 execute(false); 254 } else { 255 token->addRef(); 256 token->awaiters.push_back([execute]() { execute(true); }); 257 } 258 } 259 260 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 261 mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle, 262 CoroResume resume) { 263 std::unique_lock<std::mutex> lock(group->mu); 264 265 auto execute = [handle, resume, group](bool dropRef) { 266 if (dropRef) 267 group->dropRef(); 268 mlirAsyncRuntimeExecute(handle, resume); 269 }; 270 271 if (group->pendingTokens == 0) { 272 execute(false); 273 } else { 274 group->addRef(); 275 group->awaiters.push_back([execute]() { execute(true); }); 276 } 277 } 278 279 //===----------------------------------------------------------------------===// 280 // Small async runtime support library for testing. 281 //===----------------------------------------------------------------------===// 282 283 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 284 static thread_local std::thread::id thisId = std::this_thread::get_id(); 285 std::cout << "Current thread id: " << thisId << "\n"; 286 } 287 288 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 289