1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements basic Async runtime API for supporting Async dialect 10 // to LLVM dialect lowering. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/AsyncRuntime.h" 15 16 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 17 18 #include <atomic> 19 #include <condition_variable> 20 #include <functional> 21 #include <iostream> 22 #include <mutex> 23 #include <thread> 24 #include <vector> 25 26 //===----------------------------------------------------------------------===// 27 // Async runtime API. 28 //===----------------------------------------------------------------------===// 29 30 struct AsyncToken { 31 bool ready = false; 32 std::mutex mu; 33 std::condition_variable cv; 34 std::vector<std::function<void()>> awaiters; 35 }; 36 37 struct AsyncGroup { 38 std::atomic<int> pendingTokens{0}; 39 std::atomic<int> rank{0}; 40 std::mutex mu; 41 std::condition_variable cv; 42 std::vector<std::function<void()>> awaiters; 43 }; 44 45 // Create a new `async.token` in not-ready state. 46 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { 47 AsyncToken *token = new AsyncToken; 48 return token; 49 } 50 51 // Create a new `async.group` in empty state. 52 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() { 53 AsyncGroup *group = new AsyncGroup; 54 return group; 55 } 56 57 extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t 58 mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) { 59 std::unique_lock<std::mutex> lockToken(token->mu); 60 std::unique_lock<std::mutex> lockGroup(group->mu); 61 62 group->pendingTokens.fetch_add(1); 63 64 auto onTokenReady = [group]() { 65 // Run all group awaiters if it was the last token in the group. 66 if (group->pendingTokens.fetch_sub(1) == 1) { 67 group->cv.notify_all(); 68 for (auto &awaiter : group->awaiters) 69 awaiter(); 70 } 71 }; 72 73 if (token->ready) 74 onTokenReady(); 75 else 76 token->awaiters.push_back([onTokenReady]() { onTokenReady(); }); 77 78 return group->rank.fetch_add(1); 79 } 80 81 // Switches `async.token` to ready state and runs all awaiters. 82 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) { 83 std::unique_lock<std::mutex> lock(token->mu); 84 token->ready = true; 85 token->cv.notify_all(); 86 for (auto &awaiter : token->awaiters) 87 awaiter(); 88 } 89 90 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { 91 std::unique_lock<std::mutex> lock(token->mu); 92 if (!token->ready) 93 token->cv.wait(lock, [token] { return token->ready; }); 94 } 95 96 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 97 mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { 98 std::unique_lock<std::mutex> lock(group->mu); 99 if (group->pendingTokens != 0) 100 group->cv.wait(lock, [group] { return group->pendingTokens == 0; }); 101 } 102 103 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { 104 #if LLVM_ENABLE_THREADS 105 std::thread thread([handle, resume]() { (*resume)(handle); }); 106 thread.detach(); 107 #else 108 (*resume)(handle); 109 #endif 110 } 111 112 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, 113 CoroHandle handle, 114 CoroResume resume) { 115 std::unique_lock<std::mutex> lock(token->mu); 116 117 auto execute = [handle, resume]() { 118 mlirAsyncRuntimeExecute(handle, resume); 119 }; 120 121 if (token->ready) 122 execute(); 123 else 124 token->awaiters.push_back([execute]() { execute(); }); 125 } 126 127 extern "C" MLIR_ASYNCRUNTIME_EXPORT void 128 mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle, 129 CoroResume resume) { 130 std::unique_lock<std::mutex> lock(group->mu); 131 132 auto execute = [handle, resume]() { 133 mlirAsyncRuntimeExecute(handle, resume); 134 }; 135 136 if (group->pendingTokens == 0) 137 execute(); 138 else 139 group->awaiters.push_back([execute]() { execute(); }); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // Small async runtime support library for testing. 144 //===----------------------------------------------------------------------===// 145 146 extern "C" void mlirAsyncRuntimePrintCurrentThreadId() { 147 static thread_local std::thread::id thisId = std::this_thread::get_id(); 148 std::cout << "Current thread id: " << thisId << "\n"; 149 } 150 151 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS 152