1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements basic Async runtime API for supporting Async dialect 10 // to LLVM dialect lowering. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/AsyncRuntime.h" 15 16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 17 18 #include <atomic> 19 #include <cassert> 20 #include <condition_variable> 21 #include <functional> 22 #include <iostream> 23 #include <mutex> 24 #include <thread> 25 #include <vector> 26 27 #include "llvm/Support/ThreadPool.h" 28 29 //===----------------------------------------------------------------------===// 30 // Async runtime API. 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 35 // Forward declare class defined below. 36 class RefCounted; 37 38 // -------------------------------------------------------------------------- // 39 // AsyncRuntime orchestrates all async operations and Async runtime API is built 40 // on top of the default runtime instance. 41 // -------------------------------------------------------------------------- // 42 43 class AsyncRuntime { 44 public: 45 AsyncRuntime() : numRefCountedObjects(0) {} 46 47 ~AsyncRuntime() { 48 threadPool.wait(); // wait for the completion of all async tasks 49 assert(getNumRefCountedObjects() == 0 && 50 "all ref counted objects must be destroyed"); 51 } 52 53 int32_t getNumRefCountedObjects() { 54 return numRefCountedObjects.load(std::memory_order_relaxed); 55 } 56 57 llvm::ThreadPool &getThreadPool() { return threadPool; } 58 59 private: 60 friend class RefCounted; 61 62 // Count the total number of reference counted objects in this instance 63 // of an AsyncRuntime. For debugging purposes only. 64 void addNumRefCountedObjects() { 65 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed); 66 } 67 void dropNumRefCountedObjects() { 68 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed); 69 } 70 71 std::atomic<int32_t> numRefCountedObjects; 72 73 llvm::ThreadPool threadPool; 74 }; 75 76 // Returns the default per-process instance of an async runtime. 77 AsyncRuntime *getDefaultAsyncRuntimeInstance() { 78 static auto runtime = std::make_unique<AsyncRuntime>(); 79 return runtime.get(); 80 } 81 82 // -------------------------------------------------------------------------- // 83 // A base class for all reference counted objects created by the async runtime. 84 // -------------------------------------------------------------------------- // 85 86 class RefCounted { 87 public: 88 RefCounted(AsyncRuntime *runtime, int32_t refCount = 1) 89 : runtime(runtime), refCount(refCount) { 90 runtime->addNumRefCountedObjects(); 91 } 92 93 virtual ~RefCounted() { 94 assert(refCount.load() == 0 && "reference count must be zero"); 95 runtime->dropNumRefCountedObjects(); 96 } 97 98 RefCounted(const RefCounted &) = delete; 99 RefCounted &operator=(const RefCounted &) = delete; 100 101 void addRef(int32_t count = 1) { refCount.fetch_add(count); } 102 103 void dropRef(int32_t count = 1) { 104 int32_t previous = refCount.fetch_sub(count); 105 assert(previous >= count && "reference count should not go below zero"); 106 if (previous == count) 107 destroy(); 108 } 109 110 protected: 111 virtual void destroy() { delete this; } 112 113 private: 114 AsyncRuntime *runtime; 115 std::atomic<int32_t> refCount; 116 }; 117 118 } // namespace 119 120 struct AsyncToken : public RefCounted { 121 // AsyncToken created with a reference count of 2 because it will be returned 122 // to the `async.execute` caller and also will be later on emplaced by the 123 // asynchronously executed task. If the caller immediately will drop its 124 // reference we must ensure that the token will be alive until the 125 // asynchronous operation is completed. 126 AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} 127 128 // Internal state below guarded by a mutex. 129 std::mutex mu; 130 std::condition_variable cv; 131 132 bool ready = false; 133 std::vector<std::function<void()>> awaiters; 134 }; 135 136 struct AsyncGroup : public RefCounted { 137 AsyncGroup(AsyncRuntime *runtime) 138 : RefCounted(runtime), pendingTokens(0), rank(0) {} 139 140 std::atomic<int> pendingTokens; 141 std::atomic<int> rank; 142 143 // Internal state below guarded by a mutex. 144 std::mutex mu; 145 std::condition_variable cv; 146 147 std::vector<std::function<void()>> awaiters; 148 }; 149 150 // Adds references to reference counted runtime object. 151 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) { 152 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 153 refCounted->addRef(count); 154 } 155 156 // Drops references from reference counted runtime object. 157 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) { 158 RefCounted *refCounted = static_cast<RefCounted *>(ptr); 159 refCounted->dropRef(count); 160 } 161 162 // Create a new `async.token` in not-ready state. 163 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 164 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance()); 165 return token; 166 } 167 168 // Create a new `async.group` in empty state. 169 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() { 170 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance()); 171 return group; 172 } 173 174 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, 175 AsyncGroup *group) { 176 std::unique_lock<std::mutex> lockToken(token->mu); 177 std::unique_lock<std::mutex> lockGroup(group->mu); 178 179 // Get the rank of the token inside the group before we drop the reference. 180 int rank = group->rank.fetch_add(1); 181 group->pendingTokens.fetch_add(1); 182 183 auto onTokenReady = [group]() { 184 // Run all group awaiters if it was the last token in the group. 185 if (group->pendingTokens.fetch_sub(1) == 1) { 186 group->cv.notify_all(); 187 for (auto &awaiter : group->awaiters) 188 awaiter(); 189 } 190 }; 191 192 if (token->ready) { 193 // Update group pending tokens immediately and maybe run awaiters. 194 onTokenReady(); 195 196 } else { 197 // Update group pending tokens when token will become ready. Because this 198 // will happen asynchronously we must ensure that `group` is alive until 199 // then, and re-ackquire the lock. 200 group->addRef(); 201 202 token->awaiters.push_back([group, onTokenReady]() { 203 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 204 { 205 std::unique_lock<std::mutex> lockGroup(group->mu); 206 onTokenReady(); 207 } 208 group->dropRef(); 209 }); 210 } 211 212 return rank; 213 } 214 215 // Switches `async.token` to ready state and runs all awaiters. 216 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 217 // Make sure that `dropRef` does not destroy the mutex owned by the lock. 218 { 219 std::unique_lock<std::mutex> lock(token->mu); 220 token->ready = true; 221 token->cv.notify_all(); 222 for (auto &awaiter : token->awaiters) 223 awaiter(); 224 } 225 226 // Async tokens created with a ref count `2` to keep token alive until the 227 // async task completes. Drop this reference explicitly when token emplaced. 228 token->dropRef(); 229 } 230 231 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 232 std::unique_lock<std::mutex> lock(token->mu); 233 if (!token->ready) 234 token->cv.wait(lock, [token] { return token->ready; }); 235 } 236 237 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 238 std::unique_lock<std::mutex> lock(group->mu); 239 if (group->pendingTokens != 0) 240 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 241 } 242 243 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 244 auto *runtime = getDefaultAsyncRuntimeInstance(); 245 runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); }); 246 } 247 248 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 249 CoroHandle handle, 250 CoroResume resume) { 251 std::unique_lock<std::mutex> lock(token->mu); 252 auto execute = [handle, resume]() { (*resume)(handle); }; 253 if (token->ready) 254 execute(); 255 else 256 token->awaiters.push_back([execute]() { execute(); }); 257 } 258 259 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, 260 CoroHandle handle, 261 CoroResume resume) { 262 std::unique_lock<std::mutex> lock(group->mu); 263 auto execute = [handle, resume]() { (*resume)(handle); }; 264 if (group->pendingTokens == 0) 265 execute(); 266 else 267 group->awaiters.push_back([execute]() { execute(); }); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // Small async runtime support library for testing. 272 //===----------------------------------------------------------------------===// 273 274 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 275 static thread_local std::thread::id thisId = std::this_thread::get_id(); 276 std::cout << "Current thread id: " << thisId << std::endl; 277 } 278 279 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 280