// The MIT License (MIT)
//
// Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
22 23 #include <MTScheduler.h> 24 25 namespace MT 26 { 27 28 TaskScheduler::TaskScheduler(uint32 workerThreadsCount) 29 : roundRobinThreadIndex(0) 30 , startedThreadsCount(0) 31 { 32 #ifdef MT_INSTRUMENTED_BUILD 33 webServerPort = profilerWebServer.Serve(8080, 8090); 34 35 //initialize start time 36 startTime = MT::GetTimeMicroSeconds(); 37 #endif 38 39 if (workerThreadsCount != 0) 40 { 41 threadsCount = MT::Clamp(workerThreadsCount, (uint32)1, (uint32)MT_MAX_THREAD_COUNT); 42 } else 43 { 44 //query number of processor 45 threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT); 46 } 47 48 // create fiber pool 49 for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++) 50 { 51 FiberContext& context = fiberContext[i]; 52 context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context); 53 availableFibers.Push( &context ); 54 } 55 56 // create worker thread pool 57 for (uint32 i = 0; i < threadsCount; i++) 58 { 59 threadContext[i].SetThreadIndex(i); 60 threadContext[i].taskScheduler = this; 61 threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] ); 62 } 63 } 64 65 TaskScheduler::~TaskScheduler() 66 { 67 for (uint32 i = 0; i < threadsCount; i++) 68 { 69 threadContext[i].state.Set(internal::ThreadState::EXIT); 70 threadContext[i].hasNewTasksEvent.Signal(); 71 } 72 73 for (uint32 i = 0; i < threadsCount; i++) 74 { 75 threadContext[i].thread.Stop(); 76 } 77 } 78 79 FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task) 80 { 81 FiberContext *fiberContext = task.awaitingFiber; 82 if (fiberContext) 83 { 84 task.awaitingFiber = nullptr; 85 return fiberContext; 86 } 87 88 if (!availableFibers.TryPop(fiberContext)) 89 { 90 MT_ASSERT(false, "Fibers pool is empty"); 91 } 92 93 fiberContext->currentTask = task.desc; 94 fiberContext->currentGroup = task.group; 95 fiberContext->parentFiber = task.parentFiber; 96 return fiberContext; 97 } 98 99 void 
TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext) 100 { 101 MT_ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber"); 102 fiberContext->Reset(); 103 availableFibers.Push(fiberContext); 104 } 105 106 FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext) 107 { 108 MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 109 110 MT_ASSERT(fiberContext, "Invalid fiber context"); 111 MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task"); 112 113 // Set actual thread context to fiber 114 fiberContext->SetThreadContext(&threadContext); 115 116 // Update task status 117 fiberContext->SetStatus(FiberTaskStatus::RUNNED); 118 119 MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 120 121 // Run current task code 122 Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber); 123 124 // If task was done 125 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 126 if (taskStatus == FiberTaskStatus::FINISHED) 127 { 128 TaskGroup* taskGroup = fiberContext->currentGroup; 129 130 int groupTaskCount = 0; 131 132 // Update group status 133 if (taskGroup != nullptr) 134 { 135 groupTaskCount = taskGroup->inProgressTaskCount.Dec(); 136 MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 137 if (groupTaskCount == 0) 138 { 139 // Restore awaiting tasks 140 threadContext.RestoreAwaitingTasks(taskGroup); 141 taskGroup->allDoneEvent.Signal(); 142 } 143 } 144 145 // Update total task count 146 groupTaskCount = threadContext.taskScheduler->allGroups.inProgressTaskCount.Dec(); 147 MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 148 if (groupTaskCount == 0) 149 { 150 // Notify all tasks in all group finished 151 threadContext.taskScheduler->allGroups.allDoneEvent.Signal(); 152 } 153 154 FiberContext* parentFiberContext = fiberContext->parentFiber; 155 if (parentFiberContext != nullptr) 156 { 
157 int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec(); 158 MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!"); 159 160 if (childrenFibersCount == 0) 161 { 162 // This is a last subtask. Restore parent task 163 #if MT_FIBER_DEBUG 164 165 int ownerThread = parentFiberContext->fiber.GetOwnerThread(); 166 FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus(); 167 internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext(); 168 int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter(); 169 MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state"); 170 171 ownerThread; 172 parentTaskStatus; 173 parentThreadContext; 174 fiberUsageCounter; 175 #endif 176 177 MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 178 MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context"); 179 180 // WARNING!! Thread context can changed here! Set actual current thread context. 181 parentFiberContext->SetThreadContext(&threadContext); 182 183 MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 184 185 // All subtasks is done. 
186 // Exiting and return parent fiber to scheduler 187 return parentFiberContext; 188 } else 189 { 190 // Other subtasks still exist 191 // Exiting 192 return nullptr; 193 } 194 } else 195 { 196 // Task is finished and no parent task 197 // Exiting 198 return nullptr; 199 } 200 } 201 202 MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status") 203 return nullptr; 204 } 205 206 207 void TaskScheduler::FiberMain(void* userData) 208 { 209 FiberContext& fiberContext = *(FiberContext*)(userData); 210 for(;;) 211 { 212 MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context"); 213 MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context"); 214 MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 215 216 fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData ); 217 218 fiberContext.SetStatus(FiberTaskStatus::FINISHED); 219 220 #ifdef MT_INSTRUMENTED_BUILD 221 fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask); 222 #endif 223 224 Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber); 225 } 226 227 } 228 229 230 bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount) 231 { 232 if (workersCount <= 1) 233 { 234 return false; 235 } 236 237 uint32 victimIndex = threadContext.random.Get(); 238 239 for (uint32 attempt = 0; attempt < workersCount; attempt++) 240 { 241 uint32 index = victimIndex % workersCount; 242 if (index == threadContext.workerIndex) 243 { 244 victimIndex++; 245 index = victimIndex % workersCount; 246 } 247 248 internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index]; 249 if (victimContext.queue.TryPop(task)) 250 { 251 return true; 252 } 253 254 victimIndex++; 255 } 256 return false; 257 } 258 259 void TaskScheduler::ThreadMain( void* userData ) 260 { 261 
internal::ThreadContext& context = *(internal::ThreadContext*)(userData); 262 MT_ASSERT(context.taskScheduler, "Task scheduler must be not null!"); 263 context.schedulerFiber.CreateFromThread(context.thread); 264 265 uint32 workersCount = context.taskScheduler->GetWorkerCount(); 266 267 context.taskScheduler->startedThreadsCount.Inc(); 268 269 //Spinlock until all threads started and initialized 270 while(true) 271 { 272 if (context.taskScheduler->startedThreadsCount.Get() == (int)context.taskScheduler->threadsCount) 273 { 274 break; 275 } 276 Thread::Sleep(1); 277 } 278 279 while(context.state.Get() != internal::ThreadState::EXIT) 280 { 281 internal::GroupedTask task; 282 if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) ) 283 { 284 // There is a new task 285 FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task); 286 MT_ASSERT(fiberContext, "Can't get execution context from pool"); 287 MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed"); 288 289 while(fiberContext) 290 { 291 #ifdef MT_INSTRUMENTED_BUILD 292 context.NotifyTaskResumed(fiberContext->currentTask); 293 #endif 294 295 // prevent invalid fiber resume from child tasks, before ExecuteTask is done 296 fiberContext->childrenFibersCount.Inc(); 297 298 FiberContext* parentFiber = ExecuteTask(context, fiberContext); 299 300 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 301 302 //release guard 303 int childrenFibersCount = fiberContext->childrenFibersCount.Dec(); 304 305 // Can drop fiber context - task is finished 306 if (taskStatus == FiberTaskStatus::FINISHED) 307 { 308 MT_ASSERT( childrenFibersCount == 0, "Sanity check failed"); 309 context.taskScheduler->ReleaseFiberContext(fiberContext); 310 311 // If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit 312 fiberContext = parentFiber; 313 } else 314 { 315 MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed"); 316 317 // No 
subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask 318 if (childrenFibersCount == 0) 319 { 320 MT_ASSERT(parentFiber == nullptr, "Sanity check failed"); 321 } else 322 { 323 // If subtasks still exist, drop current task execution. task will be resumed when last subtask finished 324 break; 325 } 326 327 // If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called 328 if (taskStatus == FiberTaskStatus::AWAITING_GROUP) 329 { 330 break; 331 } 332 } 333 } //while(fiberContext) 334 335 } else 336 { 337 #ifdef MT_INSTRUMENTED_BUILD 338 int64 waitFrom = MT::GetTimeMicroSeconds(); 339 #endif 340 341 // Queue is empty and stealing attempt failed 342 // Wait new events 343 context.hasNewTasksEvent.Wait(2000); 344 345 #ifdef MT_INSTRUMENTED_BUILD 346 int64 waitTo = MT::GetTimeMicroSeconds(); 347 context.NotifyWorkerAwait(waitFrom, waitTo); 348 #endif 349 350 } 351 352 } // main thread loop 353 } 354 355 void TaskScheduler::RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState) 356 { 357 // Set parent fiber pointer 358 // Calculate the number of tasks per group 359 // Calculate total number of tasks 360 size_t count = 0; 361 for (size_t i = 0; i < buckets.Size(); ++i) 362 { 363 internal::TaskBucket& bucket = buckets[i]; 364 for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++) 365 { 366 internal::GroupedTask & task = bucket.tasks[taskIndex]; 367 368 task.parentFiber = parentFiber; 369 370 if (task.group != nullptr) 371 { 372 //TODO: reduce the number of reset calls 373 task.group->allDoneEvent.Reset(); 374 task.group->inProgressTaskCount.Inc(); 375 } 376 } 377 378 count += bucket.count; 379 } 380 381 // Increments child fibers count on parent fiber 382 if (parentFiber) 383 { 384 parentFiber->childrenFibersCount.Add((uint32)count); 385 } 386 387 if (restoredFromAwaitState == false) 388 { 389 // Increments 
all task in progress counter 390 allGroups.allDoneEvent.Reset(); 391 allGroups.inProgressTaskCount.Add((uint32)count); 392 } else 393 { 394 // If task's restored from await state, counters already in correct state 395 } 396 397 // Add to thread queue 398 for (size_t i = 0; i < buckets.Size(); ++i) 399 { 400 int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount; 401 internal::ThreadContext & context = threadContext[bucketIndex]; 402 403 internal::TaskBucket& bucket = buckets[i]; 404 405 context.queue.PushRange(bucket.tasks, bucket.count); 406 context.hasNewTasksEvent.Signal(); 407 } 408 } 409 410 bool TaskScheduler::WaitGroup(TaskGroup* group, uint32 milliseconds) 411 { 412 MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false); 413 414 return group->allDoneEvent.Wait(milliseconds); 415 } 416 417 bool TaskScheduler::WaitAll(uint32 milliseconds) 418 { 419 MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false); 420 421 return allGroups.allDoneEvent.Wait(milliseconds); 422 } 423 424 bool TaskScheduler::IsEmpty() 425 { 426 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 427 { 428 if (!threadContext[i].queue.IsEmpty()) 429 { 430 return false; 431 } 432 } 433 return true; 434 } 435 436 uint32 TaskScheduler::GetWorkerCount() const 437 { 438 return threadsCount; 439 } 440 441 bool TaskScheduler::IsWorkerThread() const 442 { 443 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 444 { 445 if (threadContext[i].thread.IsCurrentThread()) 446 { 447 return true; 448 } 449 } 450 return false; 451 } 452 453 #ifdef MT_INSTRUMENTED_BUILD 454 455 size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize) 456 { 457 if (workerIndex >= MT_MAX_THREAD_COUNT) 458 { 459 return 0; 460 } 461 462 size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize); 463 return elementsCount; 464 } 465 
466 void TaskScheduler::UpdateProfiler() 467 { 468 profilerWebServer.Update(*this); 469 } 470 471 int32 TaskScheduler::GetWebServerPort() const 472 { 473 return webServerPort; 474 } 475 476 477 478 #endif 479 } 480