1 // The MIT License (MIT) 2 // 3 // Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev 4 // 5 // Permission is hereby granted, free of charge, to any person obtaining a copy 6 // of this software and associated documentation files (the "Software"), to deal 7 // in the Software without restriction, including without limitation the rights 8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 // copies of the Software, and to permit persons to whom the Software is 10 // furnished to do so, subject to the following conditions: 11 // 12 // The above copyright notice and this permission notice shall be included in 13 // all copies or substantial portions of the Software. 14 // 15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 // THE SOFTWARE. 22 23 #pragma once 24 25 #include <MTColorTable.h> 26 #include <MTTools.h> 27 #include <MTPlatform.h> 28 #include <MTConcurrentQueueLIFO.h> 29 #include <MTStackArray.h> 30 #include <MTArrayView.h> 31 #include <MTThreadContext.h> 32 #include <MTFiberContext.h> 33 #include <MTAllocator.h> 34 #include <MTTaskPool.h> 35 #include <Scopes/MTScopes.h> 36 37 38 namespace MT 39 { 40 41 template<typename CLASS_TYPE, typename MACRO_TYPE> 42 struct CheckType 43 { 44 static_assert(std::is_same<CLASS_TYPE, MACRO_TYPE>::value, "Invalid type in MT_DECLARE_TASK macro. See CheckType template instantiation params to details."); 45 }; 46 47 struct TypeChecker 48 { 49 template <typename T> 50 static T QueryThisType(T thisPtr) 51 { 52 return (T)nullptr; 53 } 54 }; 55 56 57 template <typename T> 58 inline void CallDtor(T* p) 59 { 60 MT_UNUSED(p); 61 p->~T(); 62 } 63 64 } 65 66 #if _MSC_VER 67 68 // Visual Studio compile time check 69 #define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \ 70 void CompileTimeCheckMethod() \ 71 { \ 72 MT::CheckType< typename std::remove_pointer< decltype(MT::TypeChecker::QueryThisType(this)) >::type, typename TYPE > compileTypeTypesCheck; \ 73 compileTypeTypesCheck; \ 74 } 75 76 #else 77 78 79 // GCC, Clang and other compilers compile time check 80 #define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \ 81 void CompileTimeCheckMethod() \ 82 { \ 83 /* query this pointer type */ \ 84 typedef decltype(MT::TypeChecker::QueryThisType(this)) THIS_PTR_TYPE; \ 85 /* query class type from this pointer type */ \ 86 typedef typename std::remove_pointer<THIS_PTR_TYPE>::type CPP_TYPE; \ 87 /* define macro type */ \ 88 typedef TYPE MACRO_TYPE; \ 89 /* compile time checking that is same types */ \ 90 MT::CheckType< CPP_TYPE, MACRO_TYPE > compileTypeTypesCheck; \ 91 /* remove unused variable warning */ \ 92 MT_UNUSED(compileTypeTypesCheck); \ 93 } 94 95 #endif 96 97 98 99 100 #define MT_DECLARE_TASK_IMPL(TYPE) \ 101 \ 102 MT_COMPILE_TIME_TYPE_CHECK(TYPE) \ 103 \ 104 static void TaskEntryPoint(MT::FiberContext& fiberContext, void* userData) \ 105 { \ 106 TYPE * task = static_cast< TYPE *>(userData); \ 107 task->Do(fiberContext); \ 108 } \ 109 \ 110 static void PoolTaskDestroy(void* userData) \ 111 { \ 112 TYPE * task = static_cast< TYPE *>(userData); \ 113 MT::CallDtor( task ); \ 114 /* Find task pool header */ \ 115 MT::PoolElementHeader * poolHeader = (MT::PoolElementHeader *)((char*)userData - sizeof(MT::PoolElementHeader)); \ 116 /* Fixup pool header, mark task as unused */ \ 117 poolHeader->id.Store(MT::TaskID::UNUSED); \ 118 } \ 119 120 121 122 #ifdef MT_INSTRUMENTED_BUILD 123 #include <MTProfilerEventListener.h> 124 125 #define MT_DECLARE_TASK(TYPE, DEBUG_COLOR) \ 126 static const mt_char* GetDebugID() \ 127 { \ 128 return MT_TEXT( #TYPE ); \ 129 } \ 130 \ 131 static MT::Color::Type GetDebugColor() \ 132 { \ 133 return DEBUG_COLOR; \ 134 } \ 135 \ 136 MT_DECLARE_TASK_IMPL(TYPE); 137 138 139 #else 140 141 #define MT_DECLARE_TASK(TYPE, colorID) \ 142 MT_DECLARE_TASK_IMPL(TYPE); 143 144 #endif 145 146 147 148 149 150 151 namespace MT 152 { 153 const uint32 MT_MAX_THREAD_COUNT = 64; 154 const uint32 MT_MAX_FIBERS_COUNT = 256; 155 const uint32 MT_SCHEDULER_STACK_SIZE = 1048576; 156 const uint32 MT_FIBER_STACK_SIZE = 65536; 157 158 namespace internal 159 { 160 struct ThreadContext; 161 } 162 163 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 164 // Task scheduler 165 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 166 class TaskScheduler 167 { 168 friend class FiberContext; 169 friend struct internal::ThreadContext; 170 171 172 173 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 174 // Task group description 175 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 176 // Application can assign task group to task and later wait until group was finished. 177 class TaskGroupDescription 178 { 179 AtomicInt32 inProgressTaskCount; 180 Event allDoneEvent; 181 182 //Tasks awaiting group through FiberContext::WaitGroupAndYield call 183 ConcurrentQueueLIFO<FiberContext*> waitTasksQueue; 184 185 bool debugIsFree; 186 187 public: 188 189 MT_NOCOPYABLE(TaskGroupDescription); 190 191 TaskGroupDescription() 192 { 193 inProgressTaskCount.Store(0); 194 allDoneEvent.Create( EventReset::MANUAL, true ); 195 debugIsFree = true; 196 } 197 198 int GetTaskCount() const 199 { 200 return inProgressTaskCount.Load(); 201 } 202 203 ConcurrentQueueLIFO<FiberContext*> & GetWaitQueue() 204 { 205 return waitTasksQueue; 206 } 207 208 int Dec() 209 { 210 return inProgressTaskCount.DecFetch(); 211 } 212 213 int Inc() 214 { 215 return inProgressTaskCount.IncFetch(); 216 } 217 218 int Add(int sum) 219 { 220 return inProgressTaskCount.AddFetch(sum); 221 } 222 223 void Signal() 224 { 225 allDoneEvent.Signal(); 226 } 227 228 void Reset() 229 { 230 allDoneEvent.Reset(); 231 } 232 233 bool Wait(uint32 milliseconds) 234 { 235 return allDoneEvent.Wait(milliseconds); 236 } 237 238 void SetDebugIsFree(bool _debugIsFree) 239 { 240 debugIsFree = _debugIsFree; 241 } 242 243 bool GetDebugIsFree() const 244 { 245 return debugIsFree; 246 } 247 }; 248 249 250 // Thread index for new task 251 AtomicInt32 roundRobinThreadIndex; 252 253 // Started threads count 254 AtomicInt32 startedThreadsCount; 255 256 // Threads created by task manager 257 AtomicInt32 threadsCount; 258 internal::ThreadContext threadContext[MT_MAX_THREAD_COUNT]; 259 260 // All groups task statistic 261 TaskGroupDescription allGroups; 262 263 // Groups pool 264 ConcurrentQueueLIFO<TaskGroup> availableGroups; 265 266 // 267 TaskGroupDescription groupStats[TaskGroup::MT_MAX_GROUPS_COUNT]; 268 269 // Fibers pool 270 ConcurrentQueueLIFO<FiberContext*> availableFibers; 271 272 // Fibers context 273 FiberContext fiberContext[MT_MAX_FIBERS_COUNT]; 274 275 #ifdef MT_INSTRUMENTED_BUILD 276 IProfilerEventListener * profilerEventListener; 277 #endif 278 279 FiberContext* RequestFiberContext(internal::GroupedTask& task); 280 void ReleaseFiberContext(FiberContext* fiberExecutionContext); 281 void RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState); 282 TaskGroupDescription & GetGroupDesc(TaskGroup group); 283 284 static void ThreadMain( void* userData ); 285 static void FiberMain( void* userData ); 286 static bool TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount); 287 288 static FiberContext* ExecuteTask (internal::ThreadContext& threadContext, FiberContext* fiberContext); 289 290 public: 291 292 /// \brief Initializes a new instance of the TaskScheduler class. 293 /// \param workerThreadsCount Worker threads count. Automatically determines the required number of threads if workerThreadsCount set to 0 294 #ifdef MT_INSTRUMENTED_BUILD 295 TaskScheduler(uint32 workerThreadsCount = 0, IProfilerEventListener* listener = nullptr); 296 #else 297 TaskScheduler(uint32 workerThreadsCount = 0); 298 #endif 299 300 301 ~TaskScheduler(); 302 303 template<class TTask> 304 void RunAsync(TaskGroup group, TTask* taskArray, uint32 taskCount); 305 306 void RunAsync(TaskGroup group, TaskHandle* taskHandleArray, uint32 taskHandleCount); 307 308 309 bool WaitGroup(TaskGroup group, uint32 milliseconds); 310 bool WaitAll(uint32 milliseconds); 311 312 TaskGroup CreateGroup(); 313 void ReleaseGroup(TaskGroup group); 314 315 bool IsEmpty(); 316 317 int32 GetWorkersCount() const; 318 319 bool IsWorkerThread() const; 320 321 #ifdef MT_INSTRUMENTED_BUILD 322 323 inline IProfilerEventListener* GetProfilerEventListener() 324 { 325 return profilerEventListener; 326 } 327 328 #endif 329 }; 330 } 331 332 #include "MTScheduler.inl" 333 #include "MTFiberContext.inl" 334