// The MIT License (MIT)
//
// Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
22 23 #include <MTScheduler.h> 24 25 namespace MT 26 { 27 28 TaskScheduler::TaskScheduler(uint32 workerThreadsCount) 29 : roundRobinThreadIndex(0) 30 { 31 #ifdef MT_INSTRUMENTED_BUILD 32 webServerPort = profilerWebServer.Serve(8080, 8090); 33 34 //initialize start time 35 startTime = MT::GetTimeMicroSeconds(); 36 #endif 37 38 39 if (workerThreadsCount != 0) 40 { 41 threadsCount = workerThreadsCount; 42 } else 43 { 44 //query number of processor 45 threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT); 46 } 47 48 49 // create fiber pool 50 for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++) 51 { 52 FiberContext& context = fiberContext[i]; 53 context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context); 54 availableFibers.Push( &context ); 55 } 56 57 // create worker thread pool 58 for (uint32 i = 0; i < threadsCount; i++) 59 { 60 threadContext[i].SetThreadIndex(i); 61 threadContext[i].taskScheduler = this; 62 threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] ); 63 } 64 } 65 66 TaskScheduler::~TaskScheduler() 67 { 68 for (uint32 i = 0; i < threadsCount; i++) 69 { 70 threadContext[i].state.Set(internal::ThreadState::EXIT); 71 threadContext[i].hasNewTasksEvent.Signal(); 72 } 73 74 for (uint32 i = 0; i < threadsCount; i++) 75 { 76 threadContext[i].thread.Stop(); 77 } 78 } 79 80 FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task) 81 { 82 FiberContext *fiberContext = task.awaitingFiber; 83 if (fiberContext) 84 { 85 task.awaitingFiber = nullptr; 86 return fiberContext; 87 } 88 89 if (!availableFibers.TryPop(fiberContext)) 90 { 91 MT_ASSERT(false, "Fibers pool is empty"); 92 } 93 94 fiberContext->currentTask = task.desc; 95 fiberContext->currentGroup = task.group; 96 fiberContext->parentFiber = task.parentFiber; 97 return fiberContext; 98 } 99 100 void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext) 101 { 102 MT_ASSERT(fiberContext != nullptr, "Can't 
release nullptr Fiber"); 103 fiberContext->Reset(); 104 availableFibers.Push(fiberContext); 105 } 106 107 FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext) 108 { 109 MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 110 111 MT_ASSERT(fiberContext, "Invalid fiber context"); 112 MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task"); 113 114 // Set actual thread context to fiber 115 fiberContext->SetThreadContext(&threadContext); 116 117 // Update task status 118 fiberContext->SetStatus(FiberTaskStatus::RUNNED); 119 120 MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 121 122 // Run current task code 123 Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber); 124 125 // If task was done 126 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 127 if (taskStatus == FiberTaskStatus::FINISHED) 128 { 129 TaskGroup* taskGroup = fiberContext->currentGroup; 130 131 int groupTaskCount = 0; 132 133 // Update group status 134 if (taskGroup != nullptr) 135 { 136 groupTaskCount = taskGroup->inProgressTaskCount.Dec(); 137 MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 138 if (groupTaskCount == 0) 139 { 140 // Restore awaiting tasks 141 threadContext.RestoreAwaitingTasks(taskGroup); 142 taskGroup->allDoneEvent.Signal(); 143 } 144 } 145 146 // Update total task count 147 groupTaskCount = threadContext.taskScheduler->allGroups.inProgressTaskCount.Dec(); 148 MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 149 if (groupTaskCount == 0) 150 { 151 // Notify all tasks in all group finished 152 threadContext.taskScheduler->allGroups.allDoneEvent.Signal(); 153 } 154 155 FiberContext* parentFiberContext = fiberContext->parentFiber; 156 if (parentFiberContext != nullptr) 157 { 158 int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec(); 159 MT_ASSERT(childrenFibersCount >= 0, 
"Sanity check failed!"); 160 161 if (childrenFibersCount == 0) 162 { 163 // This is a last subtask. Restore parent task 164 #if MT_FIBER_DEBUG 165 166 int ownerThread = parentFiberContext->fiber.GetOwnerThread(); 167 FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus(); 168 internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext(); 169 int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter(); 170 MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state"); 171 172 ownerThread; 173 parentTaskStatus; 174 parentThreadContext; 175 fiberUsageCounter; 176 #endif 177 178 MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 179 MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context"); 180 181 // WARNING!! Thread context can changed here! Set actual current thread context. 182 parentFiberContext->SetThreadContext(&threadContext); 183 184 MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 185 186 // All subtasks is done. 
187 // Exiting and return parent fiber to scheduler 188 return parentFiberContext; 189 } else 190 { 191 // Other subtasks still exist 192 // Exiting 193 return nullptr; 194 } 195 } else 196 { 197 // Task is finished and no parent task 198 // Exiting 199 return nullptr; 200 } 201 } 202 203 MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status") 204 return nullptr; 205 } 206 207 208 void TaskScheduler::FiberMain(void* userData) 209 { 210 FiberContext& fiberContext = *(FiberContext*)(userData); 211 for(;;) 212 { 213 MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context"); 214 MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context"); 215 MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 216 217 fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData ); 218 219 fiberContext.SetStatus(FiberTaskStatus::FINISHED); 220 221 #ifdef MT_INSTRUMENTED_BUILD 222 fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask); 223 #endif 224 225 Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber); 226 } 227 228 } 229 230 231 bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount) 232 { 233 if (workersCount <= 1) 234 { 235 return false; 236 } 237 238 uint32 victimIndex = threadContext.random.Get(); 239 240 for (uint32 attempt = 0; attempt < workersCount; attempt++) 241 { 242 uint32 index = victimIndex % workersCount; 243 if (index == threadContext.workerIndex) 244 { 245 victimIndex++; 246 index = victimIndex % workersCount; 247 } 248 249 internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index]; 250 if (victimContext.queue.TryPop(task)) 251 { 252 return true; 253 } 254 255 victimIndex++; 256 } 257 return false; 258 } 259 260 void TaskScheduler::ThreadMain( void* userData ) 261 { 262 
internal::ThreadContext& context = *(internal::ThreadContext*)(userData); 263 MT_ASSERT(context.taskScheduler, "Task scheduler must be not null!"); 264 context.schedulerFiber.CreateFromThread(context.thread); 265 266 uint32 workersCount = context.taskScheduler->GetWorkerCount(); 267 268 while(context.state.Get() != internal::ThreadState::EXIT) 269 { 270 internal::GroupedTask task; 271 if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) ) 272 { 273 // There is a new task 274 FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task); 275 MT_ASSERT(fiberContext, "Can't get execution context from pool"); 276 MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed"); 277 278 while(fiberContext) 279 { 280 #ifdef MT_INSTRUMENTED_BUILD 281 context.NotifyTaskResumed(fiberContext->currentTask); 282 #endif 283 284 // prevent invalid fiber resume from child tasks, before ExecuteTask is done 285 fiberContext->childrenFibersCount.Inc(); 286 287 FiberContext* parentFiber = ExecuteTask(context, fiberContext); 288 289 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 290 291 //release guard 292 int childrenFibersCount = fiberContext->childrenFibersCount.Dec(); 293 294 // Can drop fiber context - task is finished 295 if (taskStatus == FiberTaskStatus::FINISHED) 296 { 297 MT_ASSERT( childrenFibersCount == 0, "Sanity check failed"); 298 context.taskScheduler->ReleaseFiberContext(fiberContext); 299 300 // If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit 301 fiberContext = parentFiber; 302 } else 303 { 304 MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed"); 305 306 // No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask 307 if (childrenFibersCount == 0) 308 { 309 MT_ASSERT(parentFiber == nullptr, "Sanity check failed"); 310 } else 311 { 312 // If subtasks still exist, drop current task execution. 
task will be resumed when last subtask finished 313 break; 314 } 315 316 // If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called 317 if (taskStatus == FiberTaskStatus::AWAITING_GROUP) 318 { 319 break; 320 } 321 } 322 } //while(fiberContext) 323 324 } else 325 { 326 #ifdef MT_INSTRUMENTED_BUILD 327 int64 waitFrom = MT::GetTimeMicroSeconds(); 328 #endif 329 330 // Queue is empty and stealing attempt failed 331 // Wait new events 332 context.hasNewTasksEvent.Wait(2000); 333 334 #ifdef MT_INSTRUMENTED_BUILD 335 int64 waitTo = MT::GetTimeMicroSeconds(); 336 context.NotifyWorkerAwait(waitFrom, waitTo); 337 #endif 338 339 } 340 341 } // main thread loop 342 } 343 344 void TaskScheduler::RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState) 345 { 346 // Set parent fiber pointer 347 // Calculate the number of tasks per group 348 // Calculate total number of tasks 349 size_t count = 0; 350 for (size_t i = 0; i < buckets.Size(); ++i) 351 { 352 internal::TaskBucket& bucket = buckets[i]; 353 for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++) 354 { 355 internal::GroupedTask & task = bucket.tasks[taskIndex]; 356 357 task.parentFiber = parentFiber; 358 359 if (task.group != nullptr) 360 { 361 //TODO: reduce the number of reset calls 362 task.group->allDoneEvent.Reset(); 363 task.group->inProgressTaskCount.Inc(); 364 } 365 } 366 367 count += bucket.count; 368 } 369 370 // Increments child fibers count on parent fiber 371 if (parentFiber) 372 { 373 parentFiber->childrenFibersCount.Add((uint32)count); 374 } 375 376 if (restoredFromAwaitState == false) 377 { 378 // Increments all task in progress counter 379 allGroups.allDoneEvent.Reset(); 380 allGroups.inProgressTaskCount.Add((uint32)count); 381 } else 382 { 383 // If task's restored from await state, counters already in correct state 384 } 385 386 // Add to thread queue 387 for (size_t i = 0; i < buckets.Size(); ++i) 
388 { 389 int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount; 390 internal::ThreadContext & context = threadContext[bucketIndex]; 391 392 internal::TaskBucket& bucket = buckets[i]; 393 394 context.queue.PushRange(bucket.tasks, bucket.count); 395 context.hasNewTasksEvent.Signal(); 396 } 397 } 398 399 bool TaskScheduler::WaitGroup(TaskGroup* group, uint32 milliseconds) 400 { 401 MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false); 402 403 return group->allDoneEvent.Wait(milliseconds); 404 } 405 406 bool TaskScheduler::WaitAll(uint32 milliseconds) 407 { 408 MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false); 409 410 return allGroups.allDoneEvent.Wait(milliseconds); 411 } 412 413 bool TaskScheduler::IsEmpty() 414 { 415 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 416 { 417 if (!threadContext[i].queue.IsEmpty()) 418 { 419 return false; 420 } 421 } 422 return true; 423 } 424 425 uint32 TaskScheduler::GetWorkerCount() const 426 { 427 return threadsCount; 428 } 429 430 bool TaskScheduler::IsWorkerThread() const 431 { 432 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 433 { 434 if (threadContext[i].thread.IsCurrentThread()) 435 { 436 return true; 437 } 438 } 439 return false; 440 } 441 442 #ifdef MT_INSTRUMENTED_BUILD 443 444 size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize) 445 { 446 if (workerIndex >= MT_MAX_THREAD_COUNT) 447 { 448 return 0; 449 } 450 451 size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize); 452 return elementsCount; 453 } 454 455 void TaskScheduler::UpdateProfiler() 456 { 457 profilerWebServer.Update(*this); 458 } 459 460 int32 TaskScheduler::GetWebServerPort() const 461 { 462 return webServerPort; 463 } 464 465 466 467 #endif 468 } 469