1 // The MIT License (MIT) 2 // 3 // Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev 4 // 5 // Permission is hereby granted, free of charge, to any person obtaining a copy 6 // of this software and associated documentation files (the "Software"), to deal 7 // in the Software without restriction, including without limitation the rights 8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 // copies of the Software, and to permit persons to whom the Software is 10 // furnished to do so, subject to the following conditions: 11 // 12 // The above copyright notice and this permission notice shall be included in 13 // all copies or substantial portions of the Software. 14 // 15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 // THE SOFTWARE. 22 23 #include <MTScheduler.h> 24 25 namespace MT 26 { 27 28 TaskScheduler::TaskScheduler() 29 : roundRobinThreadIndex(0) 30 { 31 //query number of processor 32 threadsCount = Max(Thread::GetNumberOfHardwareThreads() - 2, 1); 33 34 if (threadsCount > MT_MAX_THREAD_COUNT) 35 { 36 threadsCount = MT_MAX_THREAD_COUNT; 37 } 38 39 // create fiber pool 40 for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++) 41 { 42 FiberContext& context = fiberContext[i]; 43 context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context); 44 availableFibers.Push( &context ); 45 } 46 47 // create worker thread pool 48 for (uint32 i = 0; i < threadsCount; i++) 49 { 50 threadContext[i].SetThreadIndex(i); 51 threadContext[i].taskScheduler = this; 52 threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] ); 53 } 54 } 55 56 TaskScheduler::~TaskScheduler() 57 { 58 for (uint32 i = 0; i < threadsCount; i++) 59 { 60 threadContext[i].state.Set(internal::ThreadState::EXIT); 61 threadContext[i].hasNewTasksEvent.Signal(); 62 } 63 64 for (uint32 i = 0; i < threadsCount; i++) 65 { 66 threadContext[i].thread.Stop(); 67 } 68 } 69 70 FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task) 71 { 72 FiberContext *fiberContext = task.awaitingFiber; 73 if (fiberContext) 74 { 75 task.awaitingFiber = nullptr; 76 return fiberContext; 77 } 78 79 if (!availableFibers.TryPop(fiberContext)) 80 { 81 ASSERT(false, "Fibers pool is empty"); 82 } 83 84 fiberContext->currentTask = task.desc; 85 fiberContext->currentGroup = task.group; 86 fiberContext->parentFiber = task.parentFiber; 87 return fiberContext; 88 } 89 90 void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext) 91 { 92 ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber"); 93 fiberContext->Reset(); 94 availableFibers.Push(fiberContext); 95 } 96 97 FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext) 98 { 99 ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 100 101 ASSERT(fiberContext, "Invalid fiber context"); 102 ASSERT(fiberContext->currentTask.IsValid(), "Invalid task"); 103 ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group"); 104 105 // Set actual thread context to fiber 106 fiberContext->SetThreadContext(&threadContext); 107 108 // Update task status 109 fiberContext->SetStatus(FiberTaskStatus::RUNNED); 110 111 ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 112 113 // Run current task code 114 Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber); 115 116 // If task was done 117 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 118 if (taskStatus == FiberTaskStatus::FINISHED) 119 { 120 TaskGroup::Type taskGroup = fiberContext->currentGroup; 121 ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group."); 122 123 // Update group status 124 int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec(); 125 ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 126 if (groupTaskCount == 0) 127 { 128 // Restore awaiting tasks 129 threadContext.RestoreAwaitingTasks(taskGroup); 130 threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal(); 131 } 132 133 // Update total task count 134 groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec(); 135 ASSERT(groupTaskCount >= 0, "Sanity check failed!"); 136 if (groupTaskCount == 0) 137 { 138 // Notify all tasks in all group finished 139 threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal(); 140 } 141 142 FiberContext* parentFiberContext = fiberContext->parentFiber; 143 if (parentFiberContext != nullptr) 144 { 145 int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec(); 146 ASSERT(childrenFibersCount >= 0, "Sanity check failed!"); 147 148 if (childrenFibersCount == 0) 149 { 150 // This is a last subtask. Restore parent task 151 #if FIBER_DEBUG 152 153 int ownerThread = parentFiberContext->fiber.GetOwnerThread(); 154 FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus(); 155 internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext(); 156 int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter(); 157 ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state"); 158 159 ownerThread; 160 parentTaskStatus; 161 parentThreadContext; 162 fiberUsageCounter; 163 #endif 164 165 ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed"); 166 ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context"); 167 168 // WARNING!! Thread context can changed here! Set actual current thread context. 169 parentFiberContext->SetThreadContext(&threadContext); 170 171 ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 172 173 // All subtasks is done. 174 // Exiting and return parent fiber to scheduler 175 return parentFiberContext; 176 } else 177 { 178 // Other subtasks still exist 179 // Exiting 180 return nullptr; 181 } 182 } else 183 { 184 // Task is finished and no parent task 185 // Exiting 186 return nullptr; 187 } 188 } 189 190 ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status") 191 return nullptr; 192 } 193 194 195 void TaskScheduler::FiberMain(void* userData) 196 { 197 FiberContext& fiberContext = *(FiberContext*)(userData); 198 for(;;) 199 { 200 ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context"); 201 ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group"); 202 ASSERT(fiberContext.GetThreadContext(), "Invalid thread context"); 203 ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed"); 204 205 fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData ); 206 207 fiberContext.SetStatus(FiberTaskStatus::FINISHED); 208 209 Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber); 210 } 211 212 } 213 214 215 bool TaskScheduler::StealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task) 216 { 217 // Try to steal tasks from random worker thread 218 uint32 workersCount = threadContext.taskScheduler->GetWorkerCount(); 219 if (workersCount <= 1) 220 { 221 return false; 222 } 223 224 uint32 victimIndex = threadContext.random.Get() % workersCount; 225 if (victimIndex == threadContext.workerIndex) 226 { 227 victimIndex = victimIndex++; 228 victimIndex = victimIndex % workersCount; 229 } 230 231 internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[victimIndex]; 232 return victimContext.queue.TryPop(task); 233 } 234 235 void TaskScheduler::ThreadMain( void* userData ) 236 { 237 internal::ThreadContext& context = *(internal::ThreadContext*)(userData); 238 ASSERT(context.taskScheduler, "Task scheduler must be not null!"); 239 context.schedulerFiber.CreateFromThread(context.thread); 240 241 while(context.state.Get() != internal::ThreadState::EXIT) 242 { 243 internal::GroupedTask task; 244 if (context.queue.TryPop(task) || StealTask(context, task) ) 245 { 246 // There is a new task 247 FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task); 248 ASSERT(fiberContext, "Can't get execution context from pool"); 249 ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed"); 250 251 while(fiberContext) 252 { 253 // prevent invalid fiber resume from child tasks, before ExecuteTask is done 254 fiberContext->childrenFibersCount.Inc(); 255 256 FiberContext* parentFiber = ExecuteTask(context, fiberContext); 257 258 FiberTaskStatus::Type taskStatus = fiberContext->GetStatus(); 259 260 //release guard 261 int childrenFibersCount = fiberContext->childrenFibersCount.Dec(); 262 263 // Can drop fiber context - task is finished 264 if (taskStatus == FiberTaskStatus::FINISHED) 265 { 266 ASSERT( childrenFibersCount == 0, "Sanity check failed"); 267 268 context.taskScheduler->ReleaseFiberContext(fiberContext); 269 270 // If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit 271 fiberContext = parentFiber; 272 } else 273 { 274 ASSERT( childrenFibersCount >= 0, "Sanity check failed"); 275 276 // No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask 277 if (childrenFibersCount == 0) 278 { 279 ASSERT(parentFiber == nullptr, "Sanity check failed"); 280 } else 281 { 282 // If subtasks still exist, drop current task execution. task will be resumed when last subtask finished 283 break; 284 } 285 286 // If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called 287 if (taskStatus == FiberTaskStatus::AWAITING_GROUP) 288 { 289 break; 290 } 291 } 292 } //while(fiberContext) 293 294 } else 295 { 296 // Queue is empty and stealing attempt failed 297 // Wait new events 298 context.hasNewTasksEvent.Wait(2000); 299 } 300 301 } // main thread loop 302 } 303 304 void TaskScheduler::RunTasksImpl(fixed_array<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState) 305 { 306 // Reset counter to initial value 307 int taskCountInGroup[TaskGroup::COUNT]; 308 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 309 { 310 taskCountInGroup[i] = 0; 311 } 312 313 // Set parent fiber pointer 314 // Calculate the number of tasks per group 315 // Calculate total number of tasks 316 size_t count = 0; 317 for (size_t i = 0; i < buckets.size(); ++i) 318 { 319 internal::TaskBucket& bucket = buckets[i]; 320 for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++) 321 { 322 internal::GroupedTask & task = bucket.tasks[taskIndex]; 323 324 ASSERT(task.group < TaskGroup::COUNT, "Invalid group."); 325 326 task.parentFiber = parentFiber; 327 taskCountInGroup[task.group]++; 328 } 329 count += bucket.count; 330 } 331 332 // Increments child fibers count on parent fiber 333 if (parentFiber) 334 { 335 parentFiber->childrenFibersCount.Add((uint32)count); 336 } 337 338 if (restoredFromAwaitState == false) 339 { 340 // Increments all task in progress counter 341 allGroupStats.allDoneEvent.Reset(); 342 allGroupStats.inProgressTaskCount.Add((uint32)count); 343 344 // Increments task in progress counters (per group) 345 for (size_t i = 0; i < TaskGroup::COUNT; ++i) 346 { 347 int groupTaskCount = taskCountInGroup[i]; 348 if (groupTaskCount > 0) 349 { 350 groupStats[i].allDoneEvent.Reset(); 351 groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount); 352 } 353 } 354 } else 355 { 356 // If task's restored from await state, counters already in correct state 357 } 358 359 // Add to thread queue 360 for (size_t i = 0; i < buckets.size(); ++i) 361 { 362 int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount; 363 internal::ThreadContext & context = threadContext[bucketIndex]; 364 365 internal::TaskBucket& bucket = buckets[i]; 366 367 context.queue.PushRange(bucket.tasks, bucket.count); 368 context.hasNewTasksEvent.Signal(); 369 } 370 } 371 372 bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds) 373 { 374 VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false); 375 376 return groupStats[group].allDoneEvent.Wait(milliseconds); 377 } 378 379 bool TaskScheduler::WaitAll(uint32 milliseconds) 380 { 381 VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false); 382 383 return allGroupStats.allDoneEvent.Wait(milliseconds); 384 } 385 386 bool TaskScheduler::IsEmpty() 387 { 388 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 389 { 390 if (!threadContext[i].queue.IsEmpty()) 391 { 392 return false; 393 } 394 } 395 return true; 396 } 397 398 uint32 TaskScheduler::GetWorkerCount() const 399 { 400 return threadsCount; 401 } 402 403 bool TaskScheduler::IsWorkerThread() const 404 { 405 for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++) 406 { 407 if (threadContext[i].thread.IsCurrentThread()) 408 { 409 return true; 410 } 411 } 412 return false; 413 } 414 415 } 416