1 // The MIT License (MIT) 2 // 3 // Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev 4 // 5 // Permission is hereby granted, free of charge, to any person obtaining a copy 6 // of this software and associated documentation files (the "Software"), to deal 7 // in the Software without restriction, including without limitation the rights 8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 // copies of the Software, and to permit persons to whom the Software is 10 // furnished to do so, subject to the following conditions: 11 // 12 // The above copyright notice and this permission notice shall be included in 13 // all copies or substantial portions of the Software. 14 // 15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 // THE SOFTWARE. 22 23 #include <MTScheduler.h> 24 25 namespace MT 26 { 27 28 TaskScheduler::TaskScheduler(uint32 workerThreadsCount) 29 : roundRobinThreadIndex(0) 30 , startedThreadsCount(0) 31 { 32 #ifdef MT_INSTRUMENTED_BUILD 33 webServerPort = profilerWebServer.Serve(8080, 8090); 34 35 //initialize start time 36 startTime = MT::GetTimeMicroSeconds(); 37 #endif 38 39 workerThreadsCount = 50; 40 41 if (workerThreadsCount != 0) 42 { 43 threadsCount = MT::Clamp(workerThreadsCount, (uint32)1, (uint32)MT_MAX_THREAD_COUNT); 44 } else 45 { 46 //query number of processor 47 threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT); 48 } 49 50 // create fiber pool 51 for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++) 52 { 53 FiberContext& context = fiberContext[i]; 54 context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context); 55 availableFibers.Push( &context ); 56 } 57 58 // create worker thread pool 59 for (uint32 i = 0; i < threadsCount; i++) 60 { 61 threadContext[i].SetThreadIndex(i); 62 threadContext[i].taskScheduler = this; 63 threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] ); 64 } 65 } 66 67 TaskScheduler::~TaskScheduler() 68 { 69 for (uint32 i = 0; i < threadsCount; i++) 70 { 71 threadContext[i].state.Set(internal::ThreadState::EXIT); 72 threadContext[i].hasNewTasksEvent.Signal(); 73 } 74 75 for (uint32 i = 0; i < threadsCount; i++) 76 { 77 threadContext[i].thread.Stop(); 78 } 79 } 80 81 FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task) 82 { 83 FiberContext *fiberContext = task.awaitingFiber; 84 if (fiberContext) 85 { 86 task.awaitingFiber = nullptr; 87 return fiberContext; 88 } 89 90 if (!availableFibers.TryPop(fiberContext)) 91 { 92 MT_ASSERT(false, "Fibers pool is empty"); 93 } 94 95 fiberContext->currentTask = task.desc; 96 fiberContext->currentGroup = task.group; 97 fiberContext->parentFiber = task.parentFiber; 98 return fiberContext; 99 } 100 101 void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext) 102 { 103 MT_ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber"); 104 fiberContext->Reset(); 105 availableFibers.Push(fiberContext); 106 } 107 108 FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext) 109 { 110 MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 111 112 MT_ASSERT(fiberContext, "Invalid fiber context"); 113 MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task"); 114 115 // Set actual thread context to fiber 116 fiberContext->SetThreadContext(&threadContext); 117 118 // Update task status 119 fiberContext->SetStatus(FiberTaskStatus::RUNNED); 120 121 MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 122 123 // Run current task code 124 Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber); 125 126 // If task was done 127 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 128 if (taskStatus == FiberTaskStatus::FINISHED) 129 { 130 TaskGroup* taskGroup = fiberContext->currentGroup; 131 132 int groupTaskCount = 0; 133 134 // Update group status 135 if (taskGroup != nullptr) 136 { 137 groupTaskCount = taskGroup->Dec(); 138 MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 139 if (groupTaskCount == 0) 140 { 141 // Restore awaiting tasks 142 threadContext.RestoreAwaitingTasks(taskGroup); 143 taskGroup->Signal(); 144 } 145 } 146 147 // Update total task count 148 groupTaskCount = threadContext.taskScheduler->allGroups.Dec(); 149 MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 150 if (groupTaskCount == 0) 151 { 152 // Notify all tasks in all group finished 153 threadContext.taskScheduler->allGroups.Signal(); 154 } 155 156 FiberContext* parentFiberContext = fiberContext->parentFiber; 157 if (parentFiberContext != nullptr) 158 { 159 int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec(); 160 MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!"); 161 162 if (childrenFibersCount == 0) 163 { 164 // This is a last subtask. Restore parent task 165 #if MT_FIBER_DEBUG 166 167 int ownerThread = parentFiberContext->fiber.GetOwnerThread(); 168 FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus(); 169 internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext(); 170 int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter(); 171 MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state"); 172 173 ownerThread; 174 parentTaskStatus; 175 parentThreadContext; 176 fiberUsageCounter; 177 #endif 178 179 MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 180 MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context"); 181 182 // WARNING!! Thread context can changed here! Set actual current thread context. 183 parentFiberContext->SetThreadContext(&threadContext); 184 185 MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 186 187 // All subtasks is done. 188 // Exiting and return parent fiber to scheduler 189 return parentFiberContext; 190 } else 191 { 192 // Other subtasks still exist 193 // Exiting 194 return nullptr; 195 } 196 } else 197 { 198 // Task is finished and no parent task 199 // Exiting 200 return nullptr; 201 } 202 } 203 204 MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status") 205 return nullptr; 206 } 207 208 209 void TaskScheduler::FiberMain(void* userData) 210 { 211 FiberContext& fiberContext = *(FiberContext*)(userData); 212 for(;;) 213 { 214 MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context"); 215 MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context"); 216 MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 217 218 fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData ); 219 220 fiberContext.SetStatus(FiberTaskStatus::FINISHED); 221 222 #ifdef MT_INSTRUMENTED_BUILD 223 fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask); 224 #endif 225 226 Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber); 227 } 228 229 } 230 231 232 bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount) 233 { 234 if (workersCount <= 1) 235 { 236 return false; 237 } 238 239 uint32 victimIndex = threadContext.random.Get(); 240 241 for (uint32 attempt = 0; attempt < workersCount; attempt++) 242 { 243 uint32 index = victimIndex % workersCount; 244 if (index == threadContext.workerIndex) 245 { 246 victimIndex++; 247 index = victimIndex % workersCount; 248 } 249 250 internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index]; 251 if (victimContext.queue.TryPop(task)) 252 { 253 return true; 254 } 255 256 victimIndex++; 257 } 258 return false; 259 } 260 261 void TaskScheduler::ThreadMain( void* userData ) 262 { 263 internal::ThreadContext& context = *(internal::ThreadContext*)(userData); 264 MT_ASSERT(context.taskScheduler, "Task scheduler must be not null!"); 265 context.schedulerFiber.CreateFromThread(context.thread); 266 267 uint32 workersCount = context.taskScheduler->GetWorkerCount(); 268 269 context.taskScheduler->startedThreadsCount.Inc(); 270 271 //Spinlock until all threads started and initialized 272 while(true) 273 { 274 if (context.taskScheduler->startedThreadsCount.Get() == (int)context.taskScheduler->threadsCount) 275 { 276 break; 277 } 278 Thread::Sleep(1); 279 } 280 281 while(context.state.Get() != internal::ThreadState::EXIT) 282 { 283 internal::GroupedTask task; 284 if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) ) 285 { 286 // There is a new task 287 FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task); 288 MT_ASSERT(fiberContext, "Can't get execution context from pool"); 289 MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed"); 290 291 while(fiberContext) 292 { 293 #ifdef MT_INSTRUMENTED_BUILD 294 context.NotifyTaskResumed(fiberContext->currentTask); 295 #endif 296 297 // prevent invalid fiber resume from child tasks, before ExecuteTask is done 298 fiberContext->childrenFibersCount.Inc(); 299 300 FiberContext* parentFiber = ExecuteTask(context, fiberContext); 301 302 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 303 304 //release guard 305 int childrenFibersCount = fiberContext->childrenFibersCount.Dec(); 306 307 // Can drop fiber context - task is finished 308 if (taskStatus == FiberTaskStatus::FINISHED) 309 { 310 MT_ASSERT( childrenFibersCount == 0, "Sanity check failed"); 311 context.taskScheduler->ReleaseFiberContext(fiberContext); 312 313 // If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit 314 fiberContext = parentFiber; 315 } else 316 { 317 MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed"); 318 319 // No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask 320 if (childrenFibersCount == 0) 321 { 322 MT_ASSERT(parentFiber == nullptr, "Sanity check failed"); 323 } else 324 { 325 // If subtasks still exist, drop current task execution. task will be resumed when last subtask finished 326 break; 327 } 328 329 // If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called 330 if (taskStatus == FiberTaskStatus::AWAITING_GROUP) 331 { 332 break; 333 } 334 } 335 } //while(fiberContext) 336 337 } else 338 { 339 #ifdef MT_INSTRUMENTED_BUILD 340 int64 waitFrom = MT::GetTimeMicroSeconds(); 341 #endif 342 343 // Queue is empty and stealing attempt failed 344 // Wait new events 345 context.hasNewTasksEvent.Wait(2000); 346 347 #ifdef MT_INSTRUMENTED_BUILD 348 int64 waitTo = MT::GetTimeMicroSeconds(); 349 context.NotifyWorkerAwait(waitFrom, waitTo); 350 #endif 351 352 } 353 354 } // main thread loop 355 } 356 357 void TaskScheduler::RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState) 358 { 359 // Set parent fiber pointer 360 // Calculate the number of tasks per group 361 // Calculate total number of tasks 362 size_t count = 0; 363 for (size_t i = 0; i < buckets.Size(); ++i) 364 { 365 internal::TaskBucket& bucket = buckets[i]; 366 for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++) 367 { 368 internal::GroupedTask & task = bucket.tasks[taskIndex]; 369 370 task.parentFiber = parentFiber; 371 372 if (task.group != nullptr) 373 { 374 //TODO: reduce the number of reset calls 375 task.group->Reset(); 376 task.group->Inc(); 377 } 378 } 379 380 count += bucket.count; 381 } 382 383 // Increments child fibers count on parent fiber 384 if (parentFiber) 385 { 386 parentFiber->childrenFibersCount.Add((uint32)count); 387 } 388 389 if (restoredFromAwaitState == false) 390 { 391 // Increments all task in progress counter 392 allGroups.Reset(); 393 allGroups.Add((uint32)count); 394 } else 395 { 396 // If task's restored from await state, counters already in correct state 397 } 398 399 // Add to thread queue 400 for (size_t i = 0; i < buckets.Size(); ++i) 401 { 402 int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount; 403 internal::ThreadContext & context = threadContext[bucketIndex]; 404 405 internal::TaskBucket& bucket = buckets[i]; 406 407 context.queue.PushRange(bucket.tasks, bucket.count); 408 context.hasNewTasksEvent.Signal(); 409 } 410 } 411 412 bool TaskScheduler::WaitGroup(TaskGroup* group, uint32 milliseconds) 413 { 414 MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false); 415 416 return group->Wait(milliseconds); 417 } 418 419 bool TaskScheduler::WaitAll(uint32 milliseconds) 420 { 421 MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false); 422 423 return allGroups.Wait(milliseconds); 424 } 425 426 bool TaskScheduler::IsEmpty() 427 { 428 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 429 { 430 if (!threadContext[i].queue.IsEmpty()) 431 { 432 return false; 433 } 434 } 435 return true; 436 } 437 438 uint32 TaskScheduler::GetWorkerCount() const 439 { 440 return threadsCount; 441 } 442 443 bool TaskScheduler::IsWorkerThread() const 444 { 445 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 446 { 447 if (threadContext[i].thread.IsCurrentThread()) 448 { 449 return true; 450 } 451 } 452 return false; 453 } 454 455 #ifdef MT_INSTRUMENTED_BUILD 456 457 size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize) 458 { 459 if (workerIndex >= MT_MAX_THREAD_COUNT) 460 { 461 return 0; 462 } 463 464 size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize); 465 return elementsCount; 466 } 467 468 void TaskScheduler::UpdateProfiler() 469 { 470 profilerWebServer.Update(*this); 471 } 472 473 int32 TaskScheduler::GetWebServerPort() const 474 { 475 return webServerPort; 476 } 477 478 479 480 #endif 481 } 482