// The MIT License (MIT)
//
// Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
22 23 #include <MTScheduler.h> 24 25 namespace MT 26 { 27 28 TaskScheduler::TaskScheduler() 29 : roundRobinThreadIndex(0) 30 { 31 #ifdef MT_INSTRUMENTED_BUILD 32 profilerWebServer.Serve(8080, 8090); 33 34 //initialize static start time 35 TaskScheduler::GetStartTime(); 36 #endif 37 38 //query number of processor 39 threadsCount = Max(Thread::GetNumberOfHardwareThreads() - 2, 1); 40 41 if (threadsCount > MT_MAX_THREAD_COUNT) 42 { 43 threadsCount = MT_MAX_THREAD_COUNT; 44 } 45 46 // create fiber pool 47 for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++) 48 { 49 FiberContext& context = fiberContext[i]; 50 context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context); 51 availableFibers.Push( &context ); 52 } 53 54 // create worker thread pool 55 for (uint32 i = 0; i < threadsCount; i++) 56 { 57 threadContext[i].SetThreadIndex(i); 58 threadContext[i].taskScheduler = this; 59 threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] ); 60 } 61 } 62 63 TaskScheduler::~TaskScheduler() 64 { 65 for (uint32 i = 0; i < threadsCount; i++) 66 { 67 threadContext[i].state.Set(internal::ThreadState::EXIT); 68 threadContext[i].hasNewTasksEvent.Signal(); 69 } 70 71 for (uint32 i = 0; i < threadsCount; i++) 72 { 73 threadContext[i].thread.Stop(); 74 } 75 } 76 77 FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task) 78 { 79 FiberContext *fiberContext = task.awaitingFiber; 80 if (fiberContext) 81 { 82 task.awaitingFiber = nullptr; 83 return fiberContext; 84 } 85 86 if (!availableFibers.TryPop(fiberContext)) 87 { 88 ASSERT(false, "Fibers pool is empty"); 89 } 90 91 fiberContext->currentTask = task.desc; 92 fiberContext->currentGroup = task.group; 93 fiberContext->parentFiber = task.parentFiber; 94 return fiberContext; 95 } 96 97 void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext) 98 { 99 ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber"); 100 fiberContext->Reset(); 101 availableFibers.Push(fiberContext); 102 
} 103 104 FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext) 105 { 106 ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 107 108 ASSERT(fiberContext, "Invalid fiber context"); 109 ASSERT(fiberContext->currentTask.IsValid(), "Invalid task"); 110 ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group"); 111 112 // Set actual thread context to fiber 113 fiberContext->SetThreadContext(&threadContext); 114 115 // Update task status 116 fiberContext->SetStatus(FiberTaskStatus::RUNNED); 117 118 ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 119 120 // Run current task code 121 Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber); 122 123 // If task was done 124 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 125 if (taskStatus == FiberTaskStatus::FINISHED) 126 { 127 TaskGroup::Type taskGroup = fiberContext->currentGroup; 128 ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group."); 129 130 // Update group status 131 int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec(); 132 ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 133 if (groupTaskCount == 0) 134 { 135 // Restore awaiting tasks 136 threadContext.RestoreAwaitingTasks(taskGroup); 137 threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal(); 138 } 139 140 // Update total task count 141 groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec(); 142 ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 143 if (groupTaskCount == 0) 144 { 145 // Notify all tasks in all group finished 146 threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal(); 147 } 148 149 FiberContext* parentFiberContext = fiberContext->parentFiber; 150 if (parentFiberContext != nullptr) 151 { 152 int childrenFibersCount = 
parentFiberContext->childrenFibersCount.Dec(); 153 ASSERT(childrenFibersCount >= 0, "Sanity check failed!"); 154 155 if (childrenFibersCount == 0) 156 { 157 // This is a last subtask. Restore parent task 158 #if FIBER_DEBUG 159 160 int ownerThread = parentFiberContext->fiber.GetOwnerThread(); 161 FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus(); 162 internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext(); 163 int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter(); 164 ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state"); 165 166 ownerThread; 167 parentTaskStatus; 168 parentThreadContext; 169 fiberUsageCounter; 170 #endif 171 172 ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 173 ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context"); 174 175 // WARNING!! Thread context can changed here! Set actual current thread context. 176 parentFiberContext->SetThreadContext(&threadContext); 177 178 ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 179 180 // All subtasks is done. 
181 // Exiting and return parent fiber to scheduler 182 return parentFiberContext; 183 } else 184 { 185 // Other subtasks still exist 186 // Exiting 187 return nullptr; 188 } 189 } else 190 { 191 // Task is finished and no parent task 192 // Exiting 193 return nullptr; 194 } 195 } 196 197 ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status") 198 return nullptr; 199 } 200 201 202 void TaskScheduler::FiberMain(void* userData) 203 { 204 FiberContext& fiberContext = *(FiberContext*)(userData); 205 for(;;) 206 { 207 ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context"); 208 ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group"); 209 ASSERT(fiberContext.GetThreadContext(), "Invalid thread context"); 210 ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 211 212 fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData ); 213 214 fiberContext.SetStatus(FiberTaskStatus::FINISHED); 215 216 #ifdef MT_INSTRUMENTED_BUILD 217 fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask); 218 #endif 219 220 Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber); 221 } 222 223 } 224 225 226 bool TaskScheduler::StealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task) 227 { 228 // Try to steal tasks from random worker thread 229 uint32 workersCount = threadContext.taskScheduler->GetWorkerCount(); 230 if (workersCount <= 1) 231 { 232 return false; 233 } 234 235 uint32 victimIndex = threadContext.random.Get() % workersCount; 236 if (victimIndex == threadContext.workerIndex) 237 { 238 victimIndex = victimIndex++; 239 victimIndex = victimIndex % workersCount; 240 } 241 242 internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[victimIndex]; 243 return victimContext.queue.TryPop(task); 244 } 245 246 void TaskScheduler::ThreadMain( void* userData ) 247 { 
248 internal::ThreadContext& context = *(internal::ThreadContext*)(userData); 249 ASSERT(context.taskScheduler, "Task scheduler must be not null!"); 250 context.schedulerFiber.CreateFromThread(context.thread); 251 252 while(context.state.Get() != internal::ThreadState::EXIT) 253 { 254 internal::GroupedTask task; 255 if (context.queue.TryPop(task) || StealTask(context, task) ) 256 { 257 // There is a new task 258 FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task); 259 ASSERT(fiberContext, "Can't get execution context from pool"); 260 ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed"); 261 262 while(fiberContext) 263 { 264 #ifdef MT_INSTRUMENTED_BUILD 265 context.NotifyTaskResumed(task.desc); 266 #endif 267 268 // prevent invalid fiber resume from child tasks, before ExecuteTask is done 269 fiberContext->childrenFibersCount.Inc(); 270 271 FiberContext* parentFiber = ExecuteTask(context, fiberContext); 272 273 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 274 275 //release guard 276 int childrenFibersCount = fiberContext->childrenFibersCount.Dec(); 277 278 // Can drop fiber context - task is finished 279 if (taskStatus == FiberTaskStatus::FINISHED) 280 { 281 ASSERT( childrenFibersCount == 0, "Sanity check failed"); 282 context.taskScheduler->ReleaseFiberContext(fiberContext); 283 284 // If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit 285 fiberContext = parentFiber; 286 } else 287 { 288 ASSERT( childrenFibersCount >= 0, "Sanity check failed"); 289 290 // No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask 291 if (childrenFibersCount == 0) 292 { 293 ASSERT(parentFiber == nullptr, "Sanity check failed"); 294 } else 295 { 296 // If subtasks still exist, drop current task execution. task will be resumed when last subtask finished 297 break; 298 } 299 300 // If task is in await state drop execution. 
task will be resumed when RestoreAwaitingTasks called 301 if (taskStatus == FiberTaskStatus::AWAITING_GROUP) 302 { 303 break; 304 } 305 } 306 } //while(fiberContext) 307 308 } else 309 { 310 // Queue is empty and stealing attempt failed 311 // Wait new events 312 context.hasNewTasksEvent.Wait(2000); 313 } 314 315 } // main thread loop 316 } 317 318 void TaskScheduler::RunTasksImpl(WrapperArray<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState) 319 { 320 // Reset counter to initial value 321 int taskCountInGroup[TaskGroup::COUNT]; 322 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 323 { 324 taskCountInGroup[i] = 0; 325 } 326 327 // Set parent fiber pointer 328 // Calculate the number of tasks per group 329 // Calculate total number of tasks 330 size_t count = 0; 331 for (size_t i = 0; i < buckets.Size(); ++i) 332 { 333 internal::TaskBucket& bucket = buckets[i]; 334 for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++) 335 { 336 internal::GroupedTask & task = bucket.tasks[taskIndex]; 337 338 ASSERT(task.group < TaskGroup::COUNT, "Invalid group."); 339 340 task.parentFiber = parentFiber; 341 taskCountInGroup[task.group]++; 342 } 343 count += bucket.count; 344 } 345 346 // Increments child fibers count on parent fiber 347 if (parentFiber) 348 { 349 parentFiber->childrenFibersCount.Add((uint32)count); 350 } 351 352 if (restoredFromAwaitState == false) 353 { 354 // Increments all task in progress counter 355 allGroupStats.allDoneEvent.Reset(); 356 allGroupStats.inProgressTaskCount.Add((uint32)count); 357 358 // Increments task in progress counters (per group) 359 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 360 { 361 int groupTaskCount = taskCountInGroup[i]; 362 if (groupTaskCount > 0) 363 { 364 groupStats[i].allDoneEvent.Reset(); 365 groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount); 366 } 367 } 368 } else 369 { 370 // If task's restored from await state, counters already in correct state 371 } 372 373 // Add 
to thread queue 374 for (size_t i = 0; i < buckets.Size(); ++i) 375 { 376 int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount; 377 internal::ThreadContext & context = threadContext[bucketIndex]; 378 379 internal::TaskBucket& bucket = buckets[i]; 380 381 context.queue.PushRange(bucket.tasks, bucket.count); 382 context.hasNewTasksEvent.Signal(); 383 } 384 } 385 386 bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds) 387 { 388 VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false); 389 390 return groupStats[group].allDoneEvent.Wait(milliseconds); 391 } 392 393 bool TaskScheduler::WaitAll(uint32 milliseconds) 394 { 395 VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false); 396 397 return allGroupStats.allDoneEvent.Wait(milliseconds); 398 } 399 400 bool TaskScheduler::IsEmpty() 401 { 402 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 403 { 404 if (!threadContext[i].queue.IsEmpty()) 405 { 406 return false; 407 } 408 } 409 return true; 410 } 411 412 uint32 TaskScheduler::GetWorkerCount() const 413 { 414 return threadsCount; 415 } 416 417 bool TaskScheduler::IsWorkerThread() const 418 { 419 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 420 { 421 if (threadContext[i].thread.IsCurrentThread()) 422 { 423 return true; 424 } 425 } 426 return false; 427 } 428 429 #ifdef MT_INSTRUMENTED_BUILD 430 431 size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize) 432 { 433 if (workerIndex >= MT_MAX_THREAD_COUNT) 434 { 435 return 0; 436 } 437 438 size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize); 439 return elementsCount; 440 } 441 442 void TaskScheduler::UpdateProfiler() 443 { 444 profilerWebServer.Update(*this); 445 } 446 447 int64 TaskScheduler::GetStartTime() 448 { 449 static int64 startTime = GetTimeMicroSeconds(); 450 return startTime; 451 } 452 
453 #endif 454 } 455