1 // The MIT License (MIT) 2 // 3 // Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev 4 // 5 // Permission is hereby granted, free of charge, to any person obtaining a copy 6 // of this software and associated documentation files (the "Software"), to deal 7 // in the Software without restriction, including without limitation the rights 8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 // copies of the Software, and to permit persons to whom the Software is 10 // furnished to do so, subject to the following conditions: 11 // 12 // The above copyright notice and this permission notice shall be included in 13 // all copies or substantial portions of the Software. 14 // 15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 // THE SOFTWARE. 22 23 #include <MTScheduler.h> 24 25 namespace MT 26 { 27 28 TaskScheduler::TaskScheduler() 29 : roundRobinThreadIndex(0) 30 { 31 //query number of processor 32 threadsCount = Max(Thread::GetNumberOfHardwareThreads() - 2, 1); 33 34 if (threadsCount > MT_MAX_THREAD_COUNT) 35 { 36 threadsCount = MT_MAX_THREAD_COUNT; 37 } 38 39 // create fiber pool 40 for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++) 41 { 42 FiberContext& context = fiberContext[i]; 43 context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context); 44 availableFibers.Push( &context ); 45 } 46 47 // create worker thread pool 48 for (uint32 i = 0; i < threadsCount; i++) 49 { 50 threadContext[i].SetThreadIndex(i); 51 threadContext[i].taskScheduler = this; 52 threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] ); 53 } 54 } 55 56 TaskScheduler::~TaskScheduler() 57 { 58 for (uint32 i = 0; i < threadsCount; i++) 59 { 60 threadContext[i].state.Set(internal::ThreadState::EXIT); 61 threadContext[i].hasNewTasksEvent.Signal(); 62 } 63 64 for (uint32 i = 0; i < threadsCount; i++) 65 { 66 threadContext[i].thread.Stop(); 67 } 68 } 69 70 FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task) 71 { 72 FiberContext *fiberContext = task.awaitingFiber; 73 if (fiberContext) 74 { 75 task.awaitingFiber = nullptr; 76 return fiberContext; 77 } 78 79 if (!availableFibers.TryPop(fiberContext)) 80 { 81 ASSERT(false, "Fibers pool is empty"); 82 } 83 84 fiberContext->currentTask = task.desc; 85 fiberContext->currentGroup = task.group; 86 fiberContext->parentFiber = task.parentFiber; 87 return fiberContext; 88 } 89 90 void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext) 91 { 92 ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber"); 93 fiberContext->Reset(); 94 availableFibers.Push(fiberContext); 95 } 96 97 FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext) 98 { 99 ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 100 101 ASSERT(fiberContext, "Invalid fiber context"); 102 ASSERT(fiberContext->currentTask.IsValid(), "Invalid task"); 103 ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group"); 104 105 // Set actual thread context to fiber 106 fiberContext->SetThreadContext(&threadContext); 107 108 // Update task status 109 fiberContext->SetStatus(FiberTaskStatus::RUNNED); 110 111 ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 112 113 // Run current task code 114 Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber); 115 116 // If task was done 117 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 118 if (taskStatus == FiberTaskStatus::FINISHED) 119 { 120 TaskGroup::Type taskGroup = fiberContext->currentGroup; 121 ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group."); 122 123 // Update group status 124 int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec(); 125 ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 126 if (groupTaskCount == 0) 127 { 128 // Restore awaiting tasks 129 threadContext.RestoreAwaitingTasks(taskGroup); 130 threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal(); 131 } 132 133 // Update total task count 134 groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec(); 135 ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 136 if (groupTaskCount == 0) 137 { 138 // Notify all tasks in all group finished 139 threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal(); 140 } 141 142 FiberContext* parentFiberContext = fiberContext->parentFiber; 143 if (parentFiberContext != nullptr) 144 { 145 int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec(); 146 ASSERT(childrenFibersCount >= 0, "Sanity check failed!"); 147 148 if (childrenFibersCount == 0) 149 { 150 // This is a last subtask. Restore parent task 151 #if FIBER_DEBUG 152 153 int ownerThread = parentFiberContext->fiber.GetOwnerThread(); 154 FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus(); 155 internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext(); 156 int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter(); 157 ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state"); 158 159 ownerThread; 160 parentTaskStatus; 161 parentThreadContext; 162 fiberUsageCounter; 163 #endif 164 165 ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 166 ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context"); 167 168 // WARNING!! Thread context can changed here! Set actual current thread context. 169 parentFiberContext->SetThreadContext(&threadContext); 170 171 ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 172 173 // All subtasks is done. 174 // Exiting and return parent fiber to scheduler 175 return parentFiberContext; 176 } else 177 { 178 // Other subtasks still exist 179 // Exiting 180 return nullptr; 181 } 182 } else 183 { 184 // Task is finished and no parent task 185 // Exiting 186 return nullptr; 187 } 188 } 189 190 ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status") 191 return nullptr; 192 } 193 194 195 void TaskScheduler::FiberMain(void* userData) 196 { 197 FiberContext& fiberContext = *(FiberContext*)(userData); 198 for(;;) 199 { 200 ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context"); 201 ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group"); 202 ASSERT(fiberContext.GetThreadContext(), "Invalid thread context"); 203 ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 204 205 fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData ); 206 207 fiberContext.SetStatus(FiberTaskStatus::FINISHED); 208 209 #ifdef MT_INSTRUMENTED_BUILD 210 fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask); 211 #endif 212 213 Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber); 214 } 215 216 } 217 218 219 bool TaskScheduler::StealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task) 220 { 221 // Try to steal tasks from random worker thread 222 uint32 workersCount = threadContext.taskScheduler->GetWorkerCount(); 223 if (workersCount <= 1) 224 { 225 return false; 226 } 227 228 uint32 victimIndex = threadContext.random.Get() % workersCount; 229 if (victimIndex == threadContext.workerIndex) 230 { 231 victimIndex = victimIndex++; 232 victimIndex = victimIndex % workersCount; 233 } 234 235 internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[victimIndex]; 236 return victimContext.queue.TryPop(task); 237 } 238 239 void TaskScheduler::ThreadMain( void* userData ) 240 { 241 internal::ThreadContext& context = *(internal::ThreadContext*)(userData); 242 ASSERT(context.taskScheduler, "Task scheduler must be not null!"); 243 context.schedulerFiber.CreateFromThread(context.thread); 244 245 while(context.state.Get() != internal::ThreadState::EXIT) 246 { 247 internal::GroupedTask task; 248 if (context.queue.TryPop(task) || StealTask(context, task) ) 249 { 250 // There is a new task 251 FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task); 252 ASSERT(fiberContext, "Can't get execution context from pool"); 253 ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed"); 254 255 while(fiberContext) 256 { 257 #ifdef MT_INSTRUMENTED_BUILD 258 context.NotifyTaskResumed(task.desc); 259 #endif 260 261 // prevent invalid fiber resume from child tasks, before ExecuteTask is done 262 fiberContext->childrenFibersCount.Inc(); 263 264 FiberContext* parentFiber = ExecuteTask(context, fiberContext); 265 266 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 267 268 //release guard 269 int childrenFibersCount = fiberContext->childrenFibersCount.Dec(); 270 271 // Can drop fiber context - task is finished 272 if (taskStatus == FiberTaskStatus::FINISHED) 273 { 274 ASSERT( childrenFibersCount == 0, "Sanity check failed"); 275 context.taskScheduler->ReleaseFiberContext(fiberContext); 276 277 // If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit 278 fiberContext = parentFiber; 279 } else 280 { 281 ASSERT( childrenFibersCount >= 0, "Sanity check failed"); 282 283 // No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask 284 if (childrenFibersCount == 0) 285 { 286 ASSERT(parentFiber == nullptr, "Sanity check failed"); 287 } else 288 { 289 // If subtasks still exist, drop current task execution. task will be resumed when last subtask finished 290 break; 291 } 292 293 // If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called 294 if (taskStatus == FiberTaskStatus::AWAITING_GROUP) 295 { 296 break; 297 } 298 } 299 } //while(fiberContext) 300 301 } else 302 { 303 // Queue is empty and stealing attempt failed 304 // Wait new events 305 context.hasNewTasksEvent.Wait(2000); 306 } 307 308 } // main thread loop 309 } 310 311 void TaskScheduler::RunTasksImpl(WrapperArray<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState) 312 { 313 // Reset counter to initial value 314 int taskCountInGroup[TaskGroup::COUNT]; 315 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 316 { 317 taskCountInGroup[i] = 0; 318 } 319 320 // Set parent fiber pointer 321 // Calculate the number of tasks per group 322 // Calculate total number of tasks 323 size_t count = 0; 324 for (size_t i = 0; i < buckets.Size(); ++i) 325 { 326 internal::TaskBucket& bucket = buckets[i]; 327 for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++) 328 { 329 internal::GroupedTask & task = bucket.tasks[taskIndex]; 330 331 ASSERT(task.group < TaskGroup::COUNT, "Invalid group."); 332 333 task.parentFiber = parentFiber; 334 taskCountInGroup[task.group]++; 335 } 336 count += bucket.count; 337 } 338 339 // Increments child fibers count on parent fiber 340 if (parentFiber) 341 { 342 parentFiber->childrenFibersCount.Add((uint32)count); 343 } 344 345 if (restoredFromAwaitState == false) 346 { 347 // Increments all task in progress counter 348 allGroupStats.allDoneEvent.Reset(); 349 allGroupStats.inProgressTaskCount.Add((uint32)count); 350 351 // Increments task in progress counters (per group) 352 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 353 { 354 int groupTaskCount = taskCountInGroup[i]; 355 if (groupTaskCount > 0) 356 { 357 groupStats[i].allDoneEvent.Reset(); 358 groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount); 359 } 360 } 361 } else 362 { 363 // If task's restored from await state, counters already in correct state 364 } 365 366 // Add to thread queue 367 for (size_t i = 0; i < buckets.Size(); ++i) 368 { 369 int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount; 370 internal::ThreadContext & context = threadContext[bucketIndex]; 371 372 internal::TaskBucket& bucket = buckets[i]; 373 374 context.queue.PushRange(bucket.tasks, bucket.count); 375 context.hasNewTasksEvent.Signal(); 376 } 377 } 378 379 bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds) 380 { 381 VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false); 382 383 return groupStats[group].allDoneEvent.Wait(milliseconds); 384 } 385 386 bool TaskScheduler::WaitAll(uint32 milliseconds) 387 { 388 VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false); 389 390 return allGroupStats.allDoneEvent.Wait(milliseconds); 391 } 392 393 bool TaskScheduler::IsEmpty() 394 { 395 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 396 { 397 if (!threadContext[i].queue.IsEmpty()) 398 { 399 return false; 400 } 401 } 402 return true; 403 } 404 405 uint32 TaskScheduler::GetWorkerCount() const 406 { 407 return threadsCount; 408 } 409 410 bool TaskScheduler::IsWorkerThread() const 411 { 412 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 413 { 414 if (threadContext[i].thread.IsCurrentThread()) 415 { 416 return true; 417 } 418 } 419 return false; 420 } 421 422 #ifdef MT_INSTRUMENTED_BUILD 423 424 size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize) 425 { 426 if (workerIndex >= MT_MAX_THREAD_COUNT) 427 { 428 return 0; 429 } 430 431 size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize); 432 return elementsCount; 433 } 434 435 #endif 436 437 438 } 439