// The MIT License (MIT)
//
// Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include <MTScheduler.h>

namespace MT
{

TaskScheduler::TaskScheduler()
	: roundRobinThreadIndex(0)
{
	// Query the number of hardware threads; use all but two, with at least one worker
	threadsCount = Max(Thread::GetNumberOfHardwareThreads() - 2, 1);

	// Clamp to the compile-time thread limit
	if (threadsCount > MT_MAX_THREAD_COUNT)
	{
		threadsCount = MT_MAX_THREAD_COUNT;
	}

	// Create the fiber pool
	for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
	{
		FiberContext& context = fiberContext[i];
		context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
		availableFibers.Push( &context );
	}

	// Create the worker thread pool
	for (uint32 i = 0; i < threadsCount; i++)
	{
		threadContext[i].taskScheduler = this;
		threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
	}
}

TaskScheduler::~TaskScheduler()
{
	// Ask every worker thread to exit and wake it up
	for (uint32 i = 0; i < threadsCount; i++)
	{
		threadContext[i].state.Set(internal::ThreadState::EXIT);
		threadContext[i].hasNewTasksEvent.Signal();
	}

	// Wait for the worker threads to stop
	for (uint32 i = 0; i < threadsCount; i++)
	{
		threadContext[i].thread.Stop();
	}
}

FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
{
	// A task restored from an await state already carries its fiber
	FiberContext* fiberContext = task.awaitingFiber;
	if (fiberContext)
	{
		task.awaitingFiber = nullptr;
		return fiberContext;
	}

	// Otherwise grab a fresh fiber from the pool
	if (!availableFibers.TryPop(fiberContext))
	{
		ASSERT(false, "Fibers pool is empty");
	}

	fiberContext->currentTask = task.desc;
	fiberContext->currentGroup = task.group;
	fiberContext->parentFiber = task.parentFiber;
	return fiberContext;
}

void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
{
	ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber");
	fiberContext->Reset();
	availableFibers.Push(fiberContext);
}

FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
{
	ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");

	ASSERT(fiberContext, "Invalid fiber context");
	ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");
	ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group");

	// Bind the actual thread context to the fiber
	fiberContext->SetThreadContext(&threadContext);
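
	// Note: FiberMain reads this thread context to find the scheduler fiber to
	// switch back to, so it must be bound before the switch below. The same fiber
	// may later be resumed on a different worker thread, which is why it is
	// re-bound on every execution.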

	// Update the task status
	fiberContext->SetStatus(FiberTaskStatus::RUNNED);

	ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

	// Run the current task code
	Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);

	// Check whether the task has finished
	FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
	if (taskStatus == FiberTaskStatus::FINISHED)
	{
		TaskGroup::Type taskGroup = fiberContext->currentGroup;
		ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group.");

		// Update the group status
		int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec();
		ASSERT(groupTaskCount >= 0, "Sanity check failed!");
		if (groupTaskCount == 0)
		{
			// Restore tasks awaiting this group
			threadContext.RestoreAwaitingTasks(taskGroup);
			threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal();
		}

		// Update the total task count
		groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec();
		ASSERT(groupTaskCount >= 0, "Sanity check failed!");
		if (groupTaskCount == 0)
		{
			// Notify that all tasks in all groups have finished
			threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal();
		}

		FiberContext* parentFiberContext = fiberContext->parentFiber;
		if (parentFiberContext != nullptr)
		{
			int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
			ASSERT(childrenFibersCount >= 0, "Sanity check failed!");

			if (childrenFibersCount == 0)
			{
				// This is the last subtask. Restore the parent task
#if FIBER_DEBUG

				int ownerThread = parentFiberContext->fiber.GetOwnerThread();
				FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
				internal::ThreadContext* parentThreadContext = parentFiberContext->GetThreadContext();
				int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
				ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");

				// Touch the locals so they do not trigger unused-variable warnings
				ownerThread;
				parentTaskStatus;
				parentThreadContext;
				fiberUsageCounter;
#endif

				ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
				ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");

				// WARNING! The thread context may have changed here. Bind the actual current thread context.
				parentFiberContext->SetThreadContext(&threadContext);

				ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

				// All subtasks are done. Exit and return the parent fiber to the scheduler.
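				// Returning it hands the parent straight back to the ThreadMain loop on
				// this worker thread, so it resumes without going through the task queue.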
				return parentFiberContext;
			} else
			{
				// Other subtasks still exist
				// Exiting
				return nullptr;
			}
		} else
		{
			// The task is finished and has no parent task
			// Exiting
			return nullptr;
		}
	}

	ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status");
	return nullptr;
}


void TaskScheduler::FiberMain(void* userData)
{
	FiberContext& fiberContext = *(FiberContext*)(userData);
	for(;;)
	{
		ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
		ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group");
		ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
		ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

		// Run the task body
		fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );

		// Mark the task as finished and return control to the scheduler fiber
		fiberContext.SetStatus(FiberTaskStatus::FINISHED);

		Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
	}

}


void TaskScheduler::ThreadMain( void* userData )
{
	internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
	ASSERT(context.taskScheduler, "Task scheduler must be not null!");
	context.schedulerFiber.CreateFromThread(context.thread);

	while(context.state.Get() != internal::ThreadState::EXIT)
	{
		internal::GroupedTask task;
		if (context.queue.TryPop(task))
		{
			// There is a new task
			FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
			ASSERT(fiberContext, "Can't get execution context from pool");
			ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

			while(fiberContext)
			{
				// Prevent an invalid fiber resume from child tasks before ExecuteTask is done
				fiberContext->childrenFibersCount.Inc();

				FiberContext* parentFiber = ExecuteTask(context, fiberContext);

				FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

				// Release the guard
				int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

				// The fiber context can be dropped - the task is finished
				if (taskStatus == FiberTaskStatus::FINISHED)
				{
					ASSERT( childrenFibersCount == 0, "Sanity check failed");

					context.taskScheduler->ReleaseFiberContext(fiberContext);

					// If a parent fiber exists, transfer control to it; if the parent fiber is null, exit
					fiberContext = parentFiber;
				} else
				{
					ASSERT( childrenFibersCount >= 0, "Sanity check failed");

					// No subtasks remain here and the status is not finished; this means
					// all subtasks finished before the parent returned from ExecuteTask
					if (childrenFibersCount == 0)
					{
						ASSERT(parentFiber == nullptr, "Sanity check failed");
					} else
					{
						// Subtasks still exist, so drop the current task execution.
						// The task will be resumed when the last subtask finishes.
						break;
					}
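
					// At this point no child fibers remain outstanding. If the task is
					// not waiting on a group, the surrounding while(fiberContext) loop
					// simply resumes this fiber on the next iteration.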

					// If the task is awaiting a group, drop execution; it will be
					// resumed when RestoreAwaitingTasks is called.
					if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
					{
						break;
					}
				}
			} //while(fiberContext)

		} else
		{
			// The queue is empty
			// TODO: could try to steal tasks from other threads
			context.hasNewTasksEvent.Wait(2000);
		}

	} // main thread loop
}

void TaskScheduler::RunTasksImpl(fixed_array<internal::TaskBucket>& buckets, FiberContext* parentFiber, bool restoredFromAwaitState)
{
	// Reset the per-group counters to their initial value
	int taskCountInGroup[TaskGroup::COUNT];
	for (size_t i = 0; i < TaskGroup::COUNT; ++i)
	{
		taskCountInGroup[i] = 0;
	}

	// Set the parent fiber pointer,
	// count the number of tasks per group,
	// and calculate the total number of tasks
	size_t count = 0;
	for (size_t i = 0; i < buckets.size(); ++i)
	{
		internal::TaskBucket& bucket = buckets[i];
		for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
		{
			internal::GroupedTask& task = bucket.tasks[taskIndex];

			ASSERT(task.group < TaskGroup::COUNT, "Invalid group.");

			task.parentFiber = parentFiber;
			taskCountInGroup[task.group]++;
		}
		count += bucket.count;
	}

	// Increment the child fiber count on the parent fiber
	if (parentFiber)
	{
		parentFiber->childrenFibersCount.Add((uint32)count);
	}

	if (restoredFromAwaitState == false)
	{
		// Increment the all-groups in-progress task counter
		allGroupStats.allDoneEvent.Reset();
		allGroupStats.inProgressTaskCount.Add((uint32)count);

		// Increment the per-group in-progress task counters
		for (size_t i = 0; i < TaskGroup::COUNT; ++i)
		{
			int groupTaskCount = taskCountInGroup[i];
			if (groupTaskCount > 0)
			{
				groupStats[i].allDoneEvent.Reset();
				groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount);
			}
		}
	} else
	{
		// Tasks restored from an await state: the counters are already in the correct state
	}

	// Distribute the buckets across the worker thread queues (round-robin)
	for (size_t i = 0; i < buckets.size(); ++i)
	{
		int threadIndex = roundRobinThreadIndex.Inc() % threadsCount;
		internal::ThreadContext& context = threadContext[threadIndex];

		internal::TaskBucket& bucket = buckets[i];

		context.queue.PushRange(bucket.tasks, bucket.count);
		context.hasNewTasksEvent.Signal();
	}
}

bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds)
{
	VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);

	return groupStats[group].allDoneEvent.Wait(milliseconds);
}

bool TaskScheduler::WaitAll(uint32 milliseconds)
{
	VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);

	return allGroupStats.allDoneEvent.Wait(milliseconds);
}

bool TaskScheduler::IsEmpty()
{
	for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
	{
		if (!threadContext[i].queue.IsEmpty())
		{
			return false;
		}
	}
	return true;
}

uint32 TaskScheduler::GetWorkerCount() const
{
	return threadsCount;
}

bool TaskScheduler::IsWorkerThread() const
{
	for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
	{
		if (threadContext[i].thread.IsCurrentThread())
		{
			return true;
		}
	}
	return false;
}

}
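
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not compiled as part of the library).
// The scheduling entry point is declared in MTScheduler.h and is not visible
// in this translation unit; the RunAsync-style call and the TaskGroup::GROUP_0
// name below are assumptions. WaitGroup/WaitAll are the waiting calls defined
// above.
//
//   MT::TaskScheduler scheduler;                 // spawns the worker threads
//
//   // ... schedule some tasks into TaskGroup::GROUP_0 via the public API
//   //     declared in MTScheduler.h (e.g. a RunAsync-style call) ...
//
//   // Block the calling (non-worker) thread until the group drains or the
//   // timeout expires. Inside a task, use FiberContext.WaitGroupAndYield()
//   // instead, as enforced by the VERIFY in WaitGroup above.
//   bool groupDone = scheduler.WaitGroup(MT::TaskGroup::GROUP_0, 10000);
//
//   // Or wait until every group has finished.
//   bool allDone = scheduler.WaitAll(30000);
// ---------------------------------------------------------------------------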