1 // The MIT License (MIT) 2 // 3 // Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev 4 // 5 // Permission is hereby granted, free of charge, to any person obtaining a copy 6 // of this software and associated documentation files (the "Software"), to deal 7 // in the Software without restriction, including without limitation the rights 8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 // copies of the Software, and to permit persons to whom the Software is 10 // furnished to do so, subject to the following conditions: 11 // 12 // The above copyright notice and this permission notice shall be included in 13 // all copies or substantial portions of the Software. 14 // 15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 // THE SOFTWARE. 
#include <MTScheduler.h>

namespace MT
{

	// Constructs the scheduler: sizes the worker pool, pre-creates the fiber
	// pool, and starts one worker thread per hardware-derived slot.
	TaskScheduler::TaskScheduler()
		: roundRobinThreadIndex(0)
	{
#ifdef MT_INSTRUMENTED_BUILD
		profilerWebServer.Serve(8080, 8090);

		// Initialize the static start time (first call latches it; see GetStartTime).
		TaskScheduler::GetStartTime();
#endif

		// Use half of the reported hardware threads, but at least one worker.
		threadsCount = Max((uint32)Thread::GetNumberOfHardwareThreads() >> 1, (uint32)1);

		// Clamp to the capacity of the statically sized threadContext array.
		if (threadsCount > MT_MAX_THREAD_COUNT)
		{
			threadsCount = MT_MAX_THREAD_COUNT;
		}

		// Create the fiber pool: every fiber runs FiberMain and is pushed
		// into availableFibers for RequestFiberContext to hand out.
		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
		{
			FiberContext& context = fiberContext[i];
			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
			availableFibers.Push( &context );
		}

		// Create the worker thread pool; each worker runs ThreadMain with its
		// own ThreadContext, which carries a back-pointer to this scheduler.
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].SetThreadIndex(i);
			threadContext[i].taskScheduler = this;
			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
		}
	}

	// Shuts down all workers: first flag every thread to EXIT and wake it,
	// then stop (join) each thread in a second pass so all workers can wind
	// down concurrently instead of one at a time.
	TaskScheduler::~TaskScheduler()
	{
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].state.Set(internal::ThreadState::EXIT);
			threadContext[i].hasNewTasksEvent.Signal();
		}

		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].thread.Stop();
		}
	}

	// Returns the fiber context that should run the given task:
	// - if the task carries an awaiting fiber (a task being resumed), reuse it;
	// - otherwise pop a fresh fiber from the pool and bind the task to it.
	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
	{
		FiberContext *fiberContext = task.awaitingFiber;
		if (fiberContext)
		{
			// Resumed task: its fiber already holds the task state; just detach it.
			task.awaitingFiber = nullptr;
			return fiberContext;
		}

		if (!availableFibers.TryPop(fiberContext))
		{
			// NOTE(review): if ASSERT compiles out in release builds, fiberContext
			// stays nullptr here and the writes below dereference it — the pool
			// must be sized so this cannot happen. TODO confirm ASSERT semantics.
			ASSERT(false, "Fibers pool is empty");
		}

		// Fresh fiber: copy the task description into the context.
		fiberContext->currentTask = task.desc;
		fiberContext->currentGroup = task.group;
		fiberContext->parentFiber = task.parentFiber;
		return fiberContext;
	}

	// Resets a finished fiber context and returns it to the shared pool.
	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
	{
		ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber");
		fiberContext->Reset();
		availableFibers.Push(fiberContext);
	}

	// Runs (or resumes) the task bound to fiberContext by switching from the
	// worker's scheduler fiber into the task fiber. When control returns here:
	// - FINISHED: decrement the per-group and global in-progress counters
	//   (signalling the matching allDoneEvent when a counter hits zero) and,
	//   if this was the last child of a parent task, return the parent fiber
	//   so the caller resumes it; otherwise return nullptr.
	// - any other status (e.g. awaiting): return nullptr; the caller decides
	//   what to do based on the fiber's status.
	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
	{
		ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");

		ASSERT(fiberContext, "Invalid fiber context");
		ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");
		ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group");

		// Bind the fiber to the worker thread that is about to run it.
		fiberContext->SetThreadContext(&threadContext);

		// Mark the task as running before entering its fiber.
		fiberContext->SetStatus(FiberTaskStatus::RUNNED);

		ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

		// Transfer control to the task fiber; we come back here when the task
		// finishes, yields, or starts awaiting.
		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);

		// Inspect how the task left its fiber.
		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
		if (taskStatus == FiberTaskStatus::FINISHED)
		{
			TaskGroup::Type taskGroup = fiberContext->currentGroup;
			ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group.");

			// One fewer task in this group; the thread that drops the counter
			// to zero restores tasks awaiting the group and signals its event.
			int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec();
			ASSERT(groupTaskCount >= 0, "Sanity check failed!");
			if (groupTaskCount == 0)
			{
				// Wake tasks that were waiting for this whole group to finish.
				threadContext.RestoreAwaitingTasks(taskGroup);
				threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal();
			}

			// Same accounting for the global (all-groups) counter.
			groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec();
			ASSERT(groupTaskCount >= 0, "Sanity check failed!");
			if (groupTaskCount == 0)
			{
				// Every task in every group is done.
				threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal();
			}

			FiberContext* parentFiberContext = fiberContext->parentFiber;
			if (parentFiberContext != nullptr)
			{
				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
				ASSERT(childrenFibersCount >= 0, "Sanity check failed!");

				if (childrenFibersCount == 0)
				{
					// This was the parent's last outstanding subtask: the parent
					// must be resumed, on THIS worker thread.
#if FIBER_DEBUG

					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
					ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");

					// Touch the locals so they don't warn as unused outside the ASSERT.
					ownerThread;
					parentTaskStatus;
					parentThreadContext;
					fiberUsageCounter;
#endif

					ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
					ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");

					// WARNING: the parent may have been started on a different
					// worker thread. Rebind it to the current thread context
					// before the caller resumes it.
					parentFiberContext->SetThreadContext(&threadContext);

					ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

					// All subtasks are done: hand the parent fiber back to the
					// scheduler loop for resumption.
					return parentFiberContext;
				} else
				{
					// Other subtasks of the parent are still running; nothing to resume.
					return nullptr;
				}
			} else
			{
				// Task finished and has no parent task; nothing to resume.
				return nullptr;
			}
		}

		// Reaching here with RUNNED would mean the fiber switched back without
		// updating its status — that is a bug in the task/fiber protocol.
		// NOTE(review): the trailing ';' after this ASSERT is missing in the
		// original; it compiles only because of how the ASSERT macro expands.
		ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status")
		return nullptr;
	}


	// Fiber entry point. Each pooled fiber loops forever: run the task bound
	// to its context, mark it FINISHED, then switch back to the scheduler
	// fiber of the worker thread that ran it. The fiber is re-entered later
	// with a new task bound (by RequestFiberContext), continuing the loop.
	void TaskScheduler::FiberMain(void* userData)
	{
		FiberContext& fiberContext = *(FiberContext*)(userData);
		for(;;)
		{
			ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
			ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group");
			ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
			ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

			// Run the user task body.
			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );

			// The task body returned normally — mark it finished before
			// leaving the fiber so ExecuteTask sees the right status.
			fiberContext.SetStatus(FiberTaskStatus::FINISHED);

#ifdef MT_INSTRUMENTED_BUILD
			fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask);
#endif

			// Return control to the worker's scheduler fiber (back into ExecuteTask).
			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
		}

	}


	// Work stealing: try to pop one task from another worker's queue.
	// Starts at a random victim and probes at most workersCount queues,
	// skipping this worker's own index. Returns true if a task was stolen.
	bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount)
	{
		if (workersCount <= 1)
		{
			// No other workers to steal from.
			return false;
		}

		// Random starting point spreads contention across victims.
		uint32 victimIndex = threadContext.random.Get();

		for (uint32 attempt = 0; attempt < workersCount; attempt++)
		{
			uint32 index = victimIndex % workersCount;
			if (index == threadContext.workerIndex)
			{
				// Never steal from ourselves; advance to the next candidate.
				victimIndex++;
				index = victimIndex % workersCount;
			}

			internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index];
			if (victimContext.queue.TryPop(task))
			{
				return true;
			}

			victimIndex++;
		}
		return false;
	}

	// Worker thread entry point. Converts the thread into the scheduler
	// fiber, then loops until EXIT: take a task from the local queue (or
	// steal one), run it via ExecuteTask, and follow the chain of parent
	// fibers that become runnable; sleep on hasNewTasksEvent when idle.
	void TaskScheduler::ThreadMain( void* userData )
	{
		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
		ASSERT(context.taskScheduler, "Task scheduler must be not null!");
		// The worker thread itself becomes the "scheduler fiber" that task
		// fibers switch back to.
		context.schedulerFiber.CreateFromThread(context.thread);

		uint32 workersCount = context.taskScheduler->GetWorkerCount();

		while(context.state.Get() != internal::ThreadState::EXIT)
		{
			internal::GroupedTask task;
			if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) )
			{
				// Got a task (own queue first, then stealing).
				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
				ASSERT(fiberContext, "Can't get execution context from pool");
				ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

				// Run the task; if it completes a parent, keep going with the
				// parent fiber, walking up the chain.
				while(fiberContext)
				{
#ifdef MT_INSTRUMENTED_BUILD
					context.NotifyTaskResumed(fiberContext->currentTask);
#endif

					// Guard increment: prevents child tasks from resuming this
					// fiber before ExecuteTask has fully returned.
					fiberContext->childrenFibersCount.Inc();

					FiberContext* parentFiber = ExecuteTask(context, fiberContext);

					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

					// Release the guard taken above.
					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

					if (taskStatus == FiberTaskStatus::FINISHED)
					{
						// Task done — its fiber can go back to the pool.
						ASSERT( childrenFibersCount == 0, "Sanity check failed");
						context.taskScheduler->ReleaseFiberContext(fiberContext);

						// If a parent fiber became runnable, transfer control to
						// it; a null parent ends the inner loop.
						fiberContext = parentFiber;
					} else
					{
						ASSERT( childrenFibersCount >= 0, "Sanity check failed");

						// Not finished. A zero child count here means every
						// subtask already completed before ExecuteTask returned.
						if (childrenFibersCount == 0)
						{
							ASSERT(parentFiber == nullptr, "Sanity check failed");
						} else
						{
							// Subtasks still running: drop this task's execution;
							// the last finishing subtask will resume it.
							break;
						}

						// Task is awaiting a group: drop execution; it will be
						// resumed by RestoreAwaitingTasks when the group finishes.
						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
						{
							break;
						}
					}
				} //while(fiberContext)

			} else
			{
#ifdef MT_INSTRUMENTED_BUILD
				int64 waitFrom = MT::GetTimeMicroSeconds();
#endif

				// Own queue empty and stealing failed: sleep until new tasks
				// are signalled (2000 — presumably milliseconds; TODO confirm
				// the unit against Event::Wait) or the timeout elapses.
				context.hasNewTasksEvent.Wait(2000);

#ifdef MT_INSTRUMENTED_BUILD
				int64 waitTo = MT::GetTimeMicroSeconds();
				context.NotifyWorkerAwait(waitFrom, waitTo);
#endif

			}

		} // main thread loop
	}

	// Distributes task buckets to the workers. Tags every task with the
	// parent fiber, sizes the per-group and global in-progress counters
	// (unless the tasks are being re-queued after an await, in which case the
	// counters were never decremented), then pushes each bucket to a worker
	// queue chosen round-robin and wakes that worker.
	void TaskScheduler::RunTasksImpl(WrapperArray<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
	{
		// Per-group task tally, zero-initialized.
		int taskCountInGroup[TaskGroup::COUNT];
		for (size_t i = 0; i < TaskGroup::COUNT; ++i)
		{
			taskCountInGroup[i] = 0;
		}

		// Single pass over all buckets:
		// - set the parent fiber pointer on every task,
		// - count tasks per group,
		// - count the total number of tasks.
		size_t count = 0;
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			internal::TaskBucket& bucket = buckets[i];
			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
			{
				internal::GroupedTask & task = bucket.tasks[taskIndex];

				ASSERT(task.group < TaskGroup::COUNT, "Invalid group.");

				task.parentFiber = parentFiber;
				taskCountInGroup[task.group]++;
			}
			count += bucket.count;
		}

		// The parent now has `count` more children to wait for.
		if (parentFiber)
		{
			parentFiber->childrenFibersCount.Add((uint32)count);
		}

		if (restoredFromAwaitState == false)
		{
			// Fresh tasks: bump the global in-progress counter and re-arm the
			// all-done event.
			allGroupStats.allDoneEvent.Reset();
			allGroupStats.inProgressTaskCount.Add((uint32)count);

			// Same per group, but only for groups that actually got tasks.
			for (size_t i = 0; i < TaskGroup::COUNT; ++i)
			{
				int groupTaskCount = taskCountInGroup[i];
				if (groupTaskCount > 0)
				{
					groupStats[i].allDoneEvent.Reset();
					groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount);
				}
			}
		} else
		{
			// Tasks restored from an await state: the counters were never
			// decremented for them, so they are already correct.
		}

		// Hand each bucket to a worker queue, round-robin, and wake the worker.
		// NOTE(review): roundRobinThreadIndex.Inc() returns int; if the counter
		// ever wraps negative, `% threadsCount` could go negative and index out
		// of bounds — worth confirming the atomic's wrap behavior.
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
			internal::ThreadContext & context = threadContext[bucketIndex];

			internal::TaskBucket& bucket = buckets[i];

			context.queue.PushRange(bucket.tasks, bucket.count);
			context.hasNewTasksEvent.Signal();
		}
	}

	// Blocks the calling (non-worker) thread until every task in `group` is
	// done or the timeout elapses. Returns the event-wait result (true when
	// the group finished in time, per the Wait call below).
	bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds)
	{
		// Waiting inside a task would block a worker thread; tasks must use
		// FiberContext.WaitGroupAndYield() instead.
		VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);

		return groupStats[group].allDoneEvent.Wait(milliseconds);
	}

	// Blocks the calling (non-worker) thread until all tasks in all groups
	// are done or the timeout elapses.
	bool TaskScheduler::WaitAll(uint32 milliseconds)
	{
		VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);

		return allGroupStats.allDoneEvent.Wait(milliseconds);
	}

	// True when every worker queue is empty. Note: tasks currently executing
	// inside fibers are not in any queue, so "empty" does not mean "idle".
	bool TaskScheduler::IsEmpty()
	{
		// Scans the full static array, not just threadsCount entries; unused
		// slots have empty queues, so the answer is the same.
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (!threadContext[i].queue.IsEmpty())
			{
				return false;
			}
		}
		return true;
	}

	// Number of worker threads created by the constructor.
	uint32 TaskScheduler::GetWorkerCount() const
	{
		return threadsCount;
	}

	// True when the calling thread is one of this scheduler's workers.
	bool TaskScheduler::IsWorkerThread() const
	{
		// Scans the full static array; slots beyond threadsCount hold threads
		// that were never started and should not match the current thread.
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (threadContext[i].thread.IsCurrentThread())
			{
				return true;
			}
		}
		return false;
	}

#ifdef MT_INSTRUMENTED_BUILD

	// Drains up to dstBufferSize profiler events recorded by the given worker
	// into dstBuffer; returns the number of events copied (0 for an invalid
	// worker index).
	size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize)
	{
		if (workerIndex >= MT_MAX_THREAD_COUNT)
		{
			return 0;
		}

		size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize);
		return elementsCount;
	}

	// Pumps the embedded profiler web server.
	void TaskScheduler::UpdateProfiler()
	{
		profilerWebServer.Update(*this);
	}

	// Returns the process-wide profiling epoch: the timestamp latched on the
	// first call (function-local static), constant thereafter.
	int64 TaskScheduler::GetStartTime()
	{
		static int64 startTime = GetTimeMicroSeconds();
		return startTime;
	}

#endif
}