1 #include <MTScheduler.h> 2 3 namespace MT 4 { 5 6 TaskScheduler::TaskScheduler() 7 : roundRobinThreadIndex(0) 8 { 9 //query number of processor 10 threadsCount = Max(Thread::GetNumberOfHardwareThreads() - 2, 1); 11 12 if (threadsCount > MT_MAX_THREAD_COUNT) 13 { 14 threadsCount = MT_MAX_THREAD_COUNT; 15 } 16 17 // create fiber pool 18 for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++) 19 { 20 FiberContext& context = fiberContext[i]; 21 context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context); 22 availableFibers.Push( &context ); 23 } 24 25 // create worker thread pool 26 for (uint32 i = 0; i < threadsCount; i++) 27 { 28 threadContext[i].taskScheduler = this; 29 threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] ); 30 } 31 } 32 33 TaskScheduler::~TaskScheduler() 34 { 35 for (uint32 i = 0; i < threadsCount; i++) 36 { 37 threadContext[i].state.Set(internal::ThreadState::EXIT); 38 threadContext[i].hasNewTasksEvent.Signal(); 39 } 40 41 for (uint32 i = 0; i < threadsCount; i++) 42 { 43 threadContext[i].thread.Stop(); 44 } 45 } 46 47 FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task) 48 { 49 FiberContext *fiberContext = task.awaitingFiber; 50 if (fiberContext) 51 { 52 task.awaitingFiber = nullptr; 53 return fiberContext; 54 } 55 56 if (!availableFibers.TryPop(fiberContext)) 57 { 58 ASSERT(false, "Fibers pool is empty"); 59 } 60 61 fiberContext->currentTask = task.desc; 62 fiberContext->currentGroup = task.group; 63 fiberContext->parentFiber = task.parentFiber; 64 return fiberContext; 65 } 66 67 void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext) 68 { 69 ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber"); 70 fiberContext->Reset(); 71 availableFibers.Push(fiberContext); 72 } 73 74 FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext) 75 { 76 ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 77 78 ASSERT(fiberContext, "Invalid fiber context"); 79 ASSERT(fiberContext->currentTask.IsValid(), "Invalid task"); 80 ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group"); 81 82 // Set actual thread context to fiber 83 fiberContext->SetThreadContext(&threadContext); 84 85 // Update task status 86 fiberContext->SetStatus(FiberTaskStatus::RUNNED); 87 88 ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 89 90 // Run current task code 91 Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber); 92 93 // If task was done 94 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 95 if (taskStatus == FiberTaskStatus::FINISHED) 96 { 97 TaskGroup::Type taskGroup = fiberContext->currentGroup; 98 ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group."); 99 100 // Update group status 101 int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec(); 102 ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 103 if (groupTaskCount == 0) 104 { 105 // Restore awaiting tasks 106 threadContext.RestoreAwaitingTasks(taskGroup); 107 threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal(); 108 } 109 110 // Update total task count 111 groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec(); 112 ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 113 if (groupTaskCount == 0) 114 { 115 // Notify all tasks in all group finished 116 threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal(); 117 } 118 119 FiberContext* parentFiberContext = fiberContext->parentFiber; 120 if (parentFiberContext != nullptr) 121 { 122 int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec(); 123 ASSERT(childrenFibersCount >= 0, "Sanity check failed!"); 124 125 if (childrenFibersCount == 0) 126 { 127 // This is a last subtask. Restore parent task 128 #if FIBER_DEBUG 129 130 int ownerThread = parentFiberContext->fiber.GetOwnerThread(); 131 FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus(); 132 internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext(); 133 int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter(); 134 ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state"); 135 136 ownerThread; 137 parentTaskStatus; 138 parentThreadContext; 139 fiberUsageCounter; 140 #endif 141 142 ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 143 ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context"); 144 145 // WARNING!! Thread context can changed here! Set actual current thread context. 146 parentFiberContext->SetThreadContext(&threadContext); 147 148 ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 149 150 // All subtasks is done. 151 // Exiting and return parent fiber to scheduler 152 return parentFiberContext; 153 } else 154 { 155 // Other subtasks still exist 156 // Exiting 157 return nullptr; 158 } 159 } else 160 { 161 // Task is finished and no parent task 162 // Exiting 163 return nullptr; 164 } 165 } 166 167 ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status") 168 return nullptr; 169 } 170 171 172 void TaskScheduler::FiberMain(void* userData) 173 { 174 FiberContext& fiberContext = *(FiberContext*)(userData); 175 for(;;) 176 { 177 ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context"); 178 ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group"); 179 ASSERT(fiberContext.GetThreadContext(), "Invalid thread context"); 180 ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 181 182 fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData ); 183 184 fiberContext.SetStatus(FiberTaskStatus::FINISHED); 185 186 Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber); 187 } 188 189 } 190 191 192 void TaskScheduler::ThreadMain( void* userData ) 193 { 194 internal::ThreadContext& context = *(internal::ThreadContext*)(userData); 195 ASSERT(context.taskScheduler, "Task scheduler must be not null!"); 196 context.schedulerFiber.CreateFromThread(context.thread); 197 198 while(context.state.Get() != internal::ThreadState::EXIT) 199 { 200 internal::GroupedTask task; 201 if (context.queue.TryPop(task)) 202 { 203 // There is a new task 204 FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task); 205 ASSERT(fiberContext, "Can't get execution context from pool"); 206 ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed"); 207 208 while(fiberContext) 209 { 210 // prevent invalid fiber resume from child tasks, before ExecuteTask is done 211 fiberContext->childrenFibersCount.Inc(); 212 213 FiberContext* parentFiber = ExecuteTask(context, fiberContext); 214 215 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 216 217 //release guard 218 int childrenFibersCount = fiberContext->childrenFibersCount.Dec(); 219 220 // Can drop fiber context - task is finished 221 if (taskStatus == FiberTaskStatus::FINISHED) 222 { 223 ASSERT( childrenFibersCount == 0, "Sanity check failed"); 224 225 context.taskScheduler->ReleaseFiberContext(fiberContext); 226 227 // If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit 228 fiberContext = parentFiber; 229 } else 230 { 231 ASSERT( childrenFibersCount >= 0, "Sanity check failed"); 232 233 // No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask 234 if (childrenFibersCount == 0) 235 { 236 ASSERT(parentFiber == nullptr, "Sanity check failed"); 237 } else 238 { 239 // If subtasks still exist, drop current task execution. task will be resumed when last subtask finished 240 break; 241 } 242 243 // If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called 244 if (taskStatus == FiberTaskStatus::AWAITING_GROUP) 245 { 246 break; 247 } 248 } 249 } //while(fiberContext) 250 251 } else 252 { 253 // Queue is empty 254 // TODO: can try to steal tasks from other threads 255 context.hasNewTasksEvent.Wait(2000); 256 } 257 258 } // main thread loop 259 } 260 261 void TaskScheduler::RunTasksImpl(fixed_array<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState) 262 { 263 // Reset counter to initial value 264 int taskCountInGroup[TaskGroup::COUNT]; 265 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 266 { 267 taskCountInGroup[i] = 0; 268 } 269 270 // Set parent fiber pointer 271 // Calculate the number of tasks per group 272 // Calculate total number of tasks 273 size_t count = 0; 274 for (size_t i = 0; i < buckets.size(); ++i) 275 { 276 internal::TaskBucket& bucket = buckets[i]; 277 for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++) 278 { 279 internal::GroupedTask & task = bucket.tasks[taskIndex]; 280 281 ASSERT(task.group < TaskGroup::COUNT, "Invalid group."); 282 283 task.parentFiber = parentFiber; 284 taskCountInGroup[task.group]++; 285 } 286 count += bucket.count; 287 } 288 289 // Increments child fibers count on parent fiber 290 if (parentFiber) 291 { 292 parentFiber->childrenFibersCount.Add((uint32)count); 293 } 294 295 if (restoredFromAwaitState == false) 296 { 297 // Increments all task in progress counter 298 allGroupStats.allDoneEvent.Reset(); 299 allGroupStats.inProgressTaskCount.Add((uint32)count); 300 301 // Increments task in progress counters (per group) 302 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 303 { 304 int groupTaskCount = taskCountInGroup[i]; 305 if (groupTaskCount > 0) 306 { 307 groupStats[i].allDoneEvent.Reset(); 308 groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount); 309 } 310 } 311 } else 312 { 313 // If task's restored from await state, counters already in correct state 314 } 315 316 // Add to thread queue 317 for (size_t i = 0; i < buckets.size(); ++i) 318 { 319 int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount; 320 internal::ThreadContext & context = threadContext[bucketIndex]; 321 322 internal::TaskBucket& bucket = buckets[i]; 323 324 context.queue.PushRange(bucket.tasks, bucket.count); 325 context.hasNewTasksEvent.Signal(); 326 } 327 } 328 329 bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds) 330 { 331 VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false); 332 333 return groupStats[group].allDoneEvent.Wait(milliseconds); 334 } 335 336 bool TaskScheduler::WaitAll(uint32 milliseconds) 337 { 338 VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false); 339 340 return allGroupStats.allDoneEvent.Wait(milliseconds); 341 } 342 343 bool TaskScheduler::IsEmpty() 344 { 345 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 346 { 347 if (!threadContext[i].queue.IsEmpty()) 348 { 349 return false; 350 } 351 } 352 return true; 353 } 354 355 uint32 TaskScheduler::GetWorkerCount() const 356 { 357 return threadsCount; 358 } 359 360 bool TaskScheduler::IsWorkerThread() const 361 { 362 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 363 { 364 if (threadContext[i].thread.IsCurrentThread()) 365 { 366 return true; 367 } 368 } 369 return false; 370 } 371 372 } 373