1 #pragma once 2 3 #include <MTTools.h> 4 #include <MTPlatform.h> 5 #include <MTConcurrentQueueLIFO.h> 6 #include <MTStackArray.h> 7 #include <MTFixedArray.h> 8 9 10 namespace MT 11 { 12 const uint32 MT_MAX_THREAD_COUNT = 32; 13 const uint32 MT_MAX_FIBERS_COUNT = 128; 14 const uint32 MT_SCHEDULER_STACK_SIZE = 131072; 15 const uint32 MT_FIBER_STACK_SIZE = 32768; 16 17 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 18 // Task group 19 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 20 // Application can wait until whole group was finished. 21 namespace TaskGroup 22 { 23 enum Type 24 { 25 GROUP_0 = 0, 26 GROUP_1 = 1, 27 GROUP_2 = 2, 28 29 COUNT, 30 31 GROUP_UNDEFINED 32 }; 33 } 34 35 36 class FiberContext; 37 38 typedef void (*TTaskEntryPoint)(FiberContext & context, void* userData); 39 40 41 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 42 // Task description 43 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 44 struct TaskDesc 45 { 46 //Task entry point 47 TTaskEntryPoint taskFunc; 48 49 //Task user data (task context) 50 void* userData; 51 52 TaskDesc() 53 : taskFunc(nullptr) 54 , userData(nullptr) 55 { 56 } 57 58 TaskDesc(TTaskEntryPoint _taskFunc, void* _userData) 59 : taskFunc(_taskFunc) 60 , userData(_userData) 61 { 62 } 63 64 bool IsValid() 65 { 66 return (taskFunc != nullptr); 67 } 68 }; 69 70 struct GroupedTask; 71 struct ThreadContext; 72 73 class TaskScheduler; 74 75 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 76 struct TaskBucket 77 { 78 GroupedTask* tasks; 79 size_t count; 80 TaskBucket(GroupedTask* _tasks, size_t _count) 81 : tasks(_tasks) 82 , count(_count) 83 { 84 } 85 }; 86 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 87 88 89 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 90 // Fiber task status 91 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 92 // Task can be completed for several reasons. 93 // For example task was done or someone call Yield from the Task body. 94 namespace FiberTaskStatus 95 { 96 enum Type 97 { 98 UNKNOWN = 0, 99 RUNNED = 1, 100 FINISHED = 2, 101 AWAITING_GROUP = 3, 102 AWAITING_CHILD = 4, 103 }; 104 } 105 106 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 107 // Fiber context 108 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 109 // Context passed to fiber main function 110 class FiberContext 111 { 112 private: 113 114 void RunSubtasksAndYieldImpl(fixed_array<TaskBucket>& buckets); 115 116 public: 117 118 FiberContext(); 119 120 template<class TTask> 121 void RunSubtasksAndYield(TaskGroup::Type taskGroup, const TTask* taskArray, size_t count); 122 123 template<class TTask> 124 void RunAsync(TaskGroup::Type taskGroup, TTask* taskArray, uint32 count); 125 126 void WaitGroupAndYield(TaskGroup::Type group); 127 128 void Reset(); 129 130 void SetThreadContext(ThreadContext * _threadContext); 131 ThreadContext* GetThreadContext(); 132 133 void SetStatus(FiberTaskStatus::Type _taskStatus); 134 FiberTaskStatus::Type GetStatus() const; 135 136 private: 137 138 // Active thread context (null if fiber context is not executing now) 139 ThreadContext * threadContext; 140 141 // Active task status 142 FiberTaskStatus::Type taskStatus; 143 144 public: 145 146 // Active task attached to this fiber 147 TaskDesc currentTask; 148 149 150 // Active task group 151 TaskGroup::Type currentGroup; 152 153 // Number of children fibers 154 AtomicInt childrenFibersCount; 155 156 // Parent fiber 157 FiberContext* parentFiber; 158 159 // System Fiber 160 Fiber fiber; 161 162 // Prevent false sharing between threads 163 uint8 cacheline[64]; 164 }; 165 166 167 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 168 struct ThreadState 169 { 170 enum Type 171 { 172 ALIVE, 173 EXIT, 174 }; 175 }; 176 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 177 struct GroupedTask 178 { 179 FiberContext* awaitingFiber; 180 FiberContext* parentFiber; 181 TaskGroup::Type group; 182 TaskDesc desc; 183 184 GroupedTask() 185 : parentFiber(nullptr) 186 , awaitingFiber(nullptr) 187 , group(TaskGroup::GROUP_UNDEFINED) 188 {} 189 190 GroupedTask(TaskDesc& _desc, TaskGroup::Type _group) 191 : parentFiber(nullptr) 192 , awaitingFiber(nullptr) 193 , group(_group) 194 , desc(_desc) 195 {} 196 }; 197 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 198 // Thread (Scheduler fiber) context 199 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 200 struct ThreadContext 201 { 202 FiberContext* lastActiveFiberContext; 203 204 // pointer to task manager 205 TaskScheduler* taskScheduler; 206 207 // thread 208 Thread thread; 209 210 // scheduler fiber 211 Fiber schedulerFiber; 212 213 // task queue awaiting execution 214 ConcurrentQueueLIFO<GroupedTask> queue; 215 216 // new task was arrived to queue event 217 Event hasNewTasksEvent; 218 219 // whether thread is alive 220 AtomicInt state; 221 222 // Temporary buffer 223 std::vector<GroupedTask> descBuffer; 224 225 // prevent false sharing between threads 226 uint8 cacheline[64]; 227 228 ThreadContext(); 229 ~ThreadContext(); 230 231 void RestoreAwaitingTasks(TaskGroup::Type taskGroup); 232 }; 233 234 235 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 236 // Task scheduler 237 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 238 class TaskScheduler 239 { 240 friend class FiberContext; 241 friend struct ThreadContext; 242 243 struct GroupStats 244 { 245 AtomicInt inProgressTaskCount; 246 Event allDoneEvent; 247 248 GroupStats() 249 { 250 inProgressTaskCount.Set(0); 251 allDoneEvent.Create( EventReset::MANUAL, true ); 252 } 253 }; 254 255 // Thread index for new task 256 AtomicInt roundRobinThreadIndex; 257 258 // Threads created by task manager 259 uint32 threadsCount; 260 ThreadContext threadContext[MT_MAX_THREAD_COUNT]; 261 262 // Per group task statistic 263 GroupStats groupStats[TaskGroup::COUNT]; 264 265 // All groups task statistic 266 GroupStats allGroupStats; 267 268 269 //Task awaiting group through FiberContext::WaitGroupAndYield call 270 ConcurrentQueueLIFO<FiberContext*> waitTaskQueues[TaskGroup::COUNT]; 271 272 273 // Fibers pool 274 ConcurrentQueueLIFO<FiberContext*> availableFibers; 275 276 // Fibers context 277 FiberContext fiberContext[MT_MAX_FIBERS_COUNT]; 278 279 FiberContext* RequestFiberContext(GroupedTask& task); 280 void ReleaseFiberContext(FiberContext* fiberExecutionContext); 281 282 void RunTasksImpl(fixed_array<TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState); 283 284 static void ThreadMain( void* userData ); 285 static void FiberMain( void* userData ); 286 static FiberContext* ExecuteTask (ThreadContext& threadContext, FiberContext* fiberContext); 287 288 289 template<class T> 290 GroupedTask GetGroupedTask(TaskGroup::Type group, T * src) const 291 { 292 TaskDesc desc(T::TaskEntryPoint, (void*)(src)); 293 return GroupedTask(desc, group); 294 } 295 296 //template specialization for FiberContext* 297 template<> 298 GroupedTask GetGroupedTask(TaskGroup::Type group, FiberContext ** src) const 299 { 300 ASSERT(group == TaskGroup::GROUP_UNDEFINED, "Group must be GROUP_UNDEFINED"); 301 FiberContext * fiberContext = *src; 302 GroupedTask groupedTask(fiberContext->currentTask, fiberContext->currentGroup); 303 groupedTask.awaitingFiber = fiberContext; 304 return groupedTask; 305 } 306 307 308 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 309 // Distributes task to threads: 310 // | Task1 | Task2 | Task3 | Task4 | Task5 | Task6 | 311 // ThreadCount = 4 312 // Thread0: Task1, Task5 313 // Thread1: Task2, Task6 314 // Thread2: Task3 315 // Thread3: Task4 316 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 317 template<class TTask> 318 bool DistibuteDescriptions(TaskGroup::Type group, TTask* taskArray, fixed_array<GroupedTask>& descriptions, fixed_array<TaskBucket>& buckets) const 319 { 320 size_t index = 0; 321 322 for (size_t bucketIndex = 0; (bucketIndex < buckets.size()) && (index < descriptions.size()); ++bucketIndex) 323 { 324 size_t bucketStartIndex = index; 325 326 for (size_t i = bucketIndex; i < descriptions.size(); i += buckets.size()) 327 { 328 descriptions[index++] = GetGroupedTask(group, &taskArray[i]); 329 } 330 331 buckets[bucketIndex] = TaskBucket(&descriptions[bucketStartIndex], index - bucketStartIndex); 332 } 333 334 ASSERT(index == descriptions.size(), "Sanity check") 335 336 return index > 0; 337 } 338 339 public: 340 341 TaskScheduler(); 342 ~TaskScheduler(); 343 344 template<class TTask> 345 void RunAsync(TaskGroup::Type group, TTask* taskArray, uint32 count); 346 347 bool WaitGroup(TaskGroup::Type group, uint32 milliseconds); 348 bool WaitAll(uint32 milliseconds); 349 350 bool IsEmpty(); 351 352 uint32 GetWorkerCount() const; 353 354 bool IsWorkerThread() const; 355 }; 356 357 template<class TTask> 358 void TaskScheduler::RunAsync(TaskGroup::Type group, TTask* taskArray, uint32 count) 359 { 360 ASSERT(!IsWorkerThread(), "Can't use RunAsync inside Task. Use FiberContext.RunAsync() instead."); 361 362 fixed_array<GroupedTask> buffer(ALLOCATE_ON_STACK(GroupedTask, count), count); 363 364 size_t bucketCount = Min(threadsCount, count); 365 fixed_array<TaskBucket> buckets(ALLOCATE_ON_STACK(TaskBucket, bucketCount), bucketCount); 366 367 DistibuteDescriptions(group, taskArray, buffer, buckets); 368 RunTasksImpl(buckets, nullptr, false); 369 } 370 371 template<class TTask> 372 void FiberContext::RunSubtasksAndYield(TaskGroup::Type taskGroup, const TTask* taskArray, size_t count) 373 { 374 ASSERT(threadContext, "ThreadContext is NULL"); 375 ASSERT(count < threadContext->descBuffer.size(), "Buffer overrun!") 376 377 size_t threadsCount = threadContext->taskScheduler->GetWorkerCount(); 378 379 fixed_array<GroupedTask> buffer(&threadContext->descBuffer.front(), count); 380 381 size_t bucketCount = Min(threadsCount, count); 382 fixed_array<TaskBucket> buckets(ALLOCATE_ON_STACK(TaskBucket, bucketCount), bucketCount); 383 384 threadContext->taskScheduler->DistibuteDescriptions(taskGroup, taskArray, buffer, buckets); 385 RunSubtasksAndYieldImpl(buckets); 386 } 387 388 template<class TTask> 389 void FiberContext::RunAsync(TaskGroup::Type taskGroup, TTask* taskArray, uint32 count) 390 { 391 ASSERT(threadContext, "ThreadContext is NULL"); 392 ASSERT(threadContext->taskScheduler->IsWorkerThread(), "Can't use RunAsync outside Task. Use TaskScheduler.RunAsync() instead."); 393 394 TaskScheduler& scheduler = *(threadContext->taskScheduler); 395 396 fixed_array<GroupedTask> buffer(&threadContext->descBuffer.front(), count); 397 398 size_t bucketCount = Min(scheduler.GetWorkerCount(), count); 399 fixed_array<TaskBucket> buckets(ALLOCATE_ON_STACK(TaskBucket, bucketCount), bucketCount); 400 401 scheduler.DistibuteDescriptions(taskGroup, taskArray, buffer, buckets); 402 scheduler.RunTasksImpl(buckets, nullptr, false); 403 } 404 405 406 407 408 409 template<typename T> 410 struct TaskBase 411 { 412 static void TaskEntryPoint(MT::FiberContext& fiberContext, void* userData) 413 { 414 T* task = static_cast<T*>(userData); 415 task->Do(fiberContext); 416 } 417 }; 418 419 420 421 } 422