// The MIT License (MIT)
//
// 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// 	Permission is hereby granted, free of charge, to any person obtaining a copy
// 	of this software and associated documentation files (the "Software"), to deal
// 	in the Software without restriction, including without limitation the rights
// 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// 	copies of the Software, and to permit persons to whom the Software is
// 	furnished to do so, subject to the following conditions:
//
//  The above copyright notice and this permission notice shall be included in
// 	all copies or substantial portions of the Software.
//
// 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// 	THE SOFTWARE.
22 23 #include <MTScheduler.h> 24 25 namespace MT 26 { 27 28 TaskScheduler::TaskScheduler(uint32 workerThreadsCount) 29 : roundRobinThreadIndex(0) 30 { 31 #ifdef MT_INSTRUMENTED_BUILD 32 webServerPort = profilerWebServer.Serve(8080, 8090); 33 34 //initialize start time 35 startTime = MT::GetTimeMicroSeconds(); 36 #endif 37 38 39 if (workerThreadsCount != 0) 40 { 41 threadsCount = workerThreadsCount; 42 } else 43 { 44 //query number of processor 45 threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT); 46 } 47 48 49 // create fiber pool 50 for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++) 51 { 52 FiberContext& context = fiberContext[i]; 53 context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context); 54 availableFibers.Push( &context ); 55 } 56 57 // create worker thread pool 58 for (uint32 i = 0; i < threadsCount; i++) 59 { 60 threadContext[i].SetThreadIndex(i); 61 threadContext[i].taskScheduler = this; 62 threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] ); 63 } 64 } 65 66 TaskScheduler::~TaskScheduler() 67 { 68 for (uint32 i = 0; i < threadsCount; i++) 69 { 70 threadContext[i].state.Set(internal::ThreadState::EXIT); 71 threadContext[i].hasNewTasksEvent.Signal(); 72 } 73 74 for (uint32 i = 0; i < threadsCount; i++) 75 { 76 threadContext[i].thread.Stop(); 77 } 78 } 79 80 FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task) 81 { 82 FiberContext *fiberContext = task.awaitingFiber; 83 if (fiberContext) 84 { 85 task.awaitingFiber = nullptr; 86 return fiberContext; 87 } 88 89 if (!availableFibers.TryPop(fiberContext)) 90 { 91 MT_ASSERT(false, "Fibers pool is empty"); 92 } 93 94 fiberContext->currentTask = task.desc; 95 fiberContext->currentGroup = task.group; 96 fiberContext->parentFiber = task.parentFiber; 97 return fiberContext; 98 } 99 100 void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext) 101 { 102 MT_ASSERT(fiberContext != nullptr, "Can't 
release nullptr Fiber"); 103 fiberContext->Reset(); 104 availableFibers.Push(fiberContext); 105 } 106 107 FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext) 108 { 109 MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 110 111 MT_ASSERT(fiberContext, "Invalid fiber context"); 112 MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task"); 113 MT_ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group"); 114 115 // Set actual thread context to fiber 116 fiberContext->SetThreadContext(&threadContext); 117 118 // Update task status 119 fiberContext->SetStatus(FiberTaskStatus::RUNNED); 120 121 MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 122 123 // Run current task code 124 Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber); 125 126 // If task was done 127 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 128 if (taskStatus == FiberTaskStatus::FINISHED) 129 { 130 TaskGroup::Type taskGroup = fiberContext->currentGroup; 131 MT_ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group."); 132 133 // Update group status 134 int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec(); 135 MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 136 if (groupTaskCount == 0) 137 { 138 // Restore awaiting tasks 139 threadContext.RestoreAwaitingTasks(taskGroup); 140 threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal(); 141 } 142 143 // Update total task count 144 groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec(); 145 MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 146 if (groupTaskCount == 0) 147 { 148 // Notify all tasks in all group finished 149 threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal(); 150 } 151 152 FiberContext* parentFiberContext = 
fiberContext->parentFiber; 153 if (parentFiberContext != nullptr) 154 { 155 int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec(); 156 MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!"); 157 158 if (childrenFibersCount == 0) 159 { 160 // This is a last subtask. Restore parent task 161 #if MT_FIBER_DEBUG 162 163 int ownerThread = parentFiberContext->fiber.GetOwnerThread(); 164 FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus(); 165 internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext(); 166 int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter(); 167 MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state"); 168 169 ownerThread; 170 parentTaskStatus; 171 parentThreadContext; 172 fiberUsageCounter; 173 #endif 174 175 MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 176 MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context"); 177 178 // WARNING!! Thread context can changed here! Set actual current thread context. 179 parentFiberContext->SetThreadContext(&threadContext); 180 181 MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 182 183 // All subtasks is done. 
184 // Exiting and return parent fiber to scheduler 185 return parentFiberContext; 186 } else 187 { 188 // Other subtasks still exist 189 // Exiting 190 return nullptr; 191 } 192 } else 193 { 194 // Task is finished and no parent task 195 // Exiting 196 return nullptr; 197 } 198 } 199 200 MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status") 201 return nullptr; 202 } 203 204 205 void TaskScheduler::FiberMain(void* userData) 206 { 207 FiberContext& fiberContext = *(FiberContext*)(userData); 208 for(;;) 209 { 210 MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context"); 211 MT_ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group"); 212 MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context"); 213 MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 214 215 fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData ); 216 217 fiberContext.SetStatus(FiberTaskStatus::FINISHED); 218 219 #ifdef MT_INSTRUMENTED_BUILD 220 fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask); 221 #endif 222 223 Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber); 224 } 225 226 } 227 228 229 bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount) 230 { 231 if (workersCount <= 1) 232 { 233 return false; 234 } 235 236 uint32 victimIndex = threadContext.random.Get(); 237 238 for (uint32 attempt = 0; attempt < workersCount; attempt++) 239 { 240 uint32 index = victimIndex % workersCount; 241 if (index == threadContext.workerIndex) 242 { 243 victimIndex++; 244 index = victimIndex % workersCount; 245 } 246 247 internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index]; 248 if (victimContext.queue.TryPop(task)) 249 { 250 return true; 251 } 252 253 victimIndex++; 254 } 255 return false; 256 } 257 
258 void TaskScheduler::ThreadMain( void* userData ) 259 { 260 internal::ThreadContext& context = *(internal::ThreadContext*)(userData); 261 MT_ASSERT(context.taskScheduler, "Task scheduler must be not null!"); 262 context.schedulerFiber.CreateFromThread(context.thread); 263 264 uint32 workersCount = context.taskScheduler->GetWorkerCount(); 265 266 while(context.state.Get() != internal::ThreadState::EXIT) 267 { 268 internal::GroupedTask task; 269 if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) ) 270 { 271 // There is a new task 272 FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task); 273 MT_ASSERT(fiberContext, "Can't get execution context from pool"); 274 MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed"); 275 276 while(fiberContext) 277 { 278 #ifdef MT_INSTRUMENTED_BUILD 279 context.NotifyTaskResumed(fiberContext->currentTask); 280 #endif 281 282 // prevent invalid fiber resume from child tasks, before ExecuteTask is done 283 fiberContext->childrenFibersCount.Inc(); 284 285 FiberContext* parentFiber = ExecuteTask(context, fiberContext); 286 287 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 288 289 //release guard 290 int childrenFibersCount = fiberContext->childrenFibersCount.Dec(); 291 292 // Can drop fiber context - task is finished 293 if (taskStatus == FiberTaskStatus::FINISHED) 294 { 295 MT_ASSERT( childrenFibersCount == 0, "Sanity check failed"); 296 context.taskScheduler->ReleaseFiberContext(fiberContext); 297 298 // If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit 299 fiberContext = parentFiber; 300 } else 301 { 302 MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed"); 303 304 // No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask 305 if (childrenFibersCount == 0) 306 { 307 MT_ASSERT(parentFiber == nullptr, "Sanity check failed"); 308 } else 309 { 
310 // If subtasks still exist, drop current task execution. task will be resumed when last subtask finished 311 break; 312 } 313 314 // If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called 315 if (taskStatus == FiberTaskStatus::AWAITING_GROUP) 316 { 317 break; 318 } 319 } 320 } //while(fiberContext) 321 322 } else 323 { 324 #ifdef MT_INSTRUMENTED_BUILD 325 int64 waitFrom = MT::GetTimeMicroSeconds(); 326 #endif 327 328 // Queue is empty and stealing attempt failed 329 // Wait new events 330 context.hasNewTasksEvent.Wait(2000); 331 332 #ifdef MT_INSTRUMENTED_BUILD 333 int64 waitTo = MT::GetTimeMicroSeconds(); 334 context.NotifyWorkerAwait(waitFrom, waitTo); 335 #endif 336 337 } 338 339 } // main thread loop 340 } 341 342 void TaskScheduler::RunTasksImpl(WrapperArray<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState) 343 { 344 // Reset counter to initial value 345 int taskCountInGroup[TaskGroup::COUNT]; 346 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 347 { 348 taskCountInGroup[i] = 0; 349 } 350 351 // Set parent fiber pointer 352 // Calculate the number of tasks per group 353 // Calculate total number of tasks 354 size_t count = 0; 355 for (size_t i = 0; i < buckets.Size(); ++i) 356 { 357 internal::TaskBucket& bucket = buckets[i]; 358 for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++) 359 { 360 internal::GroupedTask & task = bucket.tasks[taskIndex]; 361 362 MT_ASSERT(task.group < TaskGroup::COUNT, "Invalid group."); 363 364 task.parentFiber = parentFiber; 365 taskCountInGroup[task.group]++; 366 } 367 count += bucket.count; 368 } 369 370 // Increments child fibers count on parent fiber 371 if (parentFiber) 372 { 373 parentFiber->childrenFibersCount.Add((uint32)count); 374 } 375 376 if (restoredFromAwaitState == false) 377 { 378 // Increments all task in progress counter 379 allGroupStats.allDoneEvent.Reset(); 380 allGroupStats.inProgressTaskCount.Add((uint32)count); 381 
382 // Increments task in progress counters (per group) 383 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 384 { 385 int groupTaskCount = taskCountInGroup[i]; 386 if (groupTaskCount > 0) 387 { 388 groupStats[i].allDoneEvent.Reset(); 389 groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount); 390 } 391 } 392 } else 393 { 394 // If task's restored from await state, counters already in correct state 395 } 396 397 // Add to thread queue 398 for (size_t i = 0; i < buckets.Size(); ++i) 399 { 400 int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount; 401 internal::ThreadContext & context = threadContext[bucketIndex]; 402 403 internal::TaskBucket& bucket = buckets[i]; 404 405 context.queue.PushRange(bucket.tasks, bucket.count); 406 context.hasNewTasksEvent.Signal(); 407 } 408 } 409 410 bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds) 411 { 412 MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false); 413 414 return groupStats[group].allDoneEvent.Wait(milliseconds); 415 } 416 417 bool TaskScheduler::WaitAll(uint32 milliseconds) 418 { 419 MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false); 420 421 return allGroupStats.allDoneEvent.Wait(milliseconds); 422 } 423 424 bool TaskScheduler::IsEmpty() 425 { 426 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 427 { 428 if (!threadContext[i].queue.IsEmpty()) 429 { 430 return false; 431 } 432 } 433 return true; 434 } 435 436 uint32 TaskScheduler::GetWorkerCount() const 437 { 438 return threadsCount; 439 } 440 441 bool TaskScheduler::IsWorkerThread() const 442 { 443 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 444 { 445 if (threadContext[i].thread.IsCurrentThread()) 446 { 447 return true; 448 } 449 } 450 return false; 451 } 452 453 #ifdef MT_INSTRUMENTED_BUILD 454 455 size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize) 456 { 
457 if (workerIndex >= MT_MAX_THREAD_COUNT) 458 { 459 return 0; 460 } 461 462 size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize); 463 return elementsCount; 464 } 465 466 void TaskScheduler::UpdateProfiler() 467 { 468 profilerWebServer.Update(*this); 469 } 470 471 int32 TaskScheduler::GetWebServerPort() const 472 { 473 return webServerPort; 474 } 475 476 477 478 #endif 479 } 480