1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #include <MTScheduler.h>
24 
25 namespace MT
26 {
27 
28 	TaskScheduler::TaskScheduler(uint32 workerThreadsCount)
29 		: roundRobinThreadIndex(0)
30 		, startedThreadsCount(0)
31 	{
32 #ifdef MT_INSTRUMENTED_BUILD
33 		webServerPort = profilerWebServer.Serve(8080, 8090);
34 
35 		//initialize start time
36 		startTime = MT::GetTimeMicroSeconds();
37 #endif
38 
39 		if (workerThreadsCount != 0)
40 		{
41 			threadsCount = MT::Clamp(workerThreadsCount, (uint32)1, (uint32)MT_MAX_THREAD_COUNT);
42 		} else
43 		{
44 			//query number of processor
45 			threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT);
46 		}
47 
48 		// create fiber pool
49 		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
50 		{
51 			FiberContext& context = fiberContext[i];
52 			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
53 			availableFibers.Push( &context );
54 		}
55 
56 		for (uint32 i = 0; i < MT_MAX_GROUPS_COUNT; i++)
57 		{
58 			if (i != MT::DEFAULT_GROUP)
59 			{
60 				availableGroups.Push( (TaskGroup)i );
61 			}
62 		}
63 
64 		groupStats[MT::DEFAULT_GROUP].debugIsFree = false;
65 
66 		// create worker thread pool
67 		for (uint32 i = 0; i < threadsCount; i++)
68 		{
69 			threadContext[i].SetThreadIndex(i);
70 			threadContext[i].taskScheduler = this;
71 			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
72 		}
73 	}
74 
75 	TaskScheduler::~TaskScheduler()
76 	{
77 		for (uint32 i = 0; i < threadsCount; i++)
78 		{
79 			threadContext[i].state.Set(internal::ThreadState::EXIT);
80 			threadContext[i].hasNewTasksEvent.Signal();
81 		}
82 
83 		for (uint32 i = 0; i < threadsCount; i++)
84 		{
85 			threadContext[i].thread.Stop();
86 		}
87 	}
88 
89 	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
90 	{
91 		FiberContext *fiberContext = task.awaitingFiber;
92 		if (fiberContext)
93 		{
94 			task.awaitingFiber = nullptr;
95 			return fiberContext;
96 		}
97 
98 		if (!availableFibers.TryPop(fiberContext))
99 		{
100 			MT_ASSERT(false, "Fibers pool is empty");
101 		}
102 
103 		fiberContext->currentTask = task.desc;
104 		fiberContext->currentGroup = task.group;
105 		fiberContext->parentFiber = task.parentFiber;
106 		return fiberContext;
107 	}
108 
109 	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
110 	{
111 		MT_ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber");
112 		fiberContext->Reset();
113 		availableFibers.Push(fiberContext);
114 	}
115 
116 	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
117 	{
118 		MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
119 
120 		MT_ASSERT(fiberContext, "Invalid fiber context");
121 		MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");
122 
123 		// Set actual thread context to fiber
124 		fiberContext->SetThreadContext(&threadContext);
125 
126 		// Update task status
127 		fiberContext->SetStatus(FiberTaskStatus::RUNNED);
128 
129 		MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
130 
131 		// Run current task code
132 		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);
133 
134 		// If task was done
135 		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
136 		if (taskStatus == FiberTaskStatus::FINISHED)
137 		{
138 			TaskGroup taskGroup = fiberContext->currentGroup;
139 
140 			TaskScheduler::TaskGroupDescription  & groupDesc = threadContext.taskScheduler->GetGroupDesc(taskGroup);
141 
142 			// Update group status
143 			int groupTaskCount = groupDesc.Dec();
144 			MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
145 			if (groupTaskCount == 0)
146 			{
147 				// Restore awaiting tasks
148 				threadContext.RestoreAwaitingTasks(taskGroup);
149 
150 				// All restored tasks can be already finished on this line.
151 				// That's why you can't release groups from worker threads, if worker thread release group, than you can't Signal to released group.
152 
153 				// Signal pending threads that group work is finished. Group can be destroyed after this call.
154 				groupDesc.Signal();
155 
156 				fiberContext->currentGroup = MT::INVALID_GROUP;
157 			}
158 
159 			// Update total task count
160 			int allGroupTaskCount = threadContext.taskScheduler->allGroups.Dec();
161 			MT_ASSERT(allGroupTaskCount >= 0, "Sanity check failed!");
162 			if (allGroupTaskCount == 0)
163 			{
164 				// Notify all tasks in all group finished
165 				threadContext.taskScheduler->allGroups.Signal();
166 			}
167 
168 			FiberContext* parentFiberContext = fiberContext->parentFiber;
169 			if (parentFiberContext != nullptr)
170 			{
171 				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
172 				MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!");
173 
174 				if (childrenFibersCount == 0)
175 				{
176 					// This is a last subtask. Restore parent task
177 #if MT_FIBER_DEBUG
178 
179 					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
180 					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
181 					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
182 					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
183 					MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");
184 
185 					ownerThread;
186 					parentTaskStatus;
187 					parentThreadContext;
188 					fiberUsageCounter;
189 #endif
190 
191 					MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
192 					MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");
193 
194 					// WARNING!! Thread context can changed here! Set actual current thread context.
195 					parentFiberContext->SetThreadContext(&threadContext);
196 
197 					MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
198 
199 					// All subtasks is done.
200 					// Exiting and return parent fiber to scheduler
201 					return parentFiberContext;
202 				} else
203 				{
204 					// Other subtasks still exist
205 					// Exiting
206 					return nullptr;
207 				}
208 			} else
209 			{
210 				// Task is finished and no parent task
211 				// Exiting
212 				return nullptr;
213 			}
214 		}
215 
216 		MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status")
217 		return nullptr;
218 	}
219 
220 
221 	void TaskScheduler::FiberMain(void* userData)
222 	{
223 		FiberContext& fiberContext = *(FiberContext*)(userData);
224 		for(;;)
225 		{
226 			MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
227 			MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
228 			MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
229 
230 			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );
231 
232 			fiberContext.SetStatus(FiberTaskStatus::FINISHED);
233 
234 #ifdef MT_INSTRUMENTED_BUILD
235 			fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask);
236 #endif
237 
238 			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
239 		}
240 
241 	}
242 
243 
244 	bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount)
245 	{
246 		if (workersCount <= 1)
247 		{
248 			return false;
249 		}
250 
251 		uint32 victimIndex = threadContext.random.Get();
252 
253 		for (uint32 attempt = 0; attempt < workersCount; attempt++)
254 		{
255 			uint32 index = victimIndex % workersCount;
256 			if (index == threadContext.workerIndex)
257 			{
258 				victimIndex++;
259 				index = victimIndex % workersCount;
260 			}
261 
262 			internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index];
263 			if (victimContext.queue.TryPop(task))
264 			{
265 				return true;
266 			}
267 
268 			victimIndex++;
269 		}
270 		return false;
271 	}
272 
	/// Entry point of every worker thread.
	///
	/// Turns the thread into a scheduler fiber, waits until all workers have started, then loops
	/// until the thread state becomes EXIT: pop a task from the own queue (or steal one), run it on
	/// a fiber, and keep resuming any fiber handed back by ExecuteTask. When no task is available
	/// the worker sleeps on hasNewTasksEvent with a timeout.
	///
	/// @param userData pointer to this worker's internal::ThreadContext.
	void TaskScheduler::ThreadMain( void* userData )
	{
		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
		MT_ASSERT(context.taskScheduler, "Task scheduler must be not null!");
		// Convert this OS thread into the fiber that task fibers switch back to
		context.schedulerFiber.CreateFromThread(context.thread);

		uint32 workersCount = context.taskScheduler->GetWorkerCount();

		context.taskScheduler->startedThreadsCount.Inc();

		// Poll (sleeping between polls) until all worker threads have started and initialized;
		// presumably so nobody steals from a worker context that is not ready yet
		while(true)
		{
			if (context.taskScheduler->startedThreadsCount.Get() == (int)context.taskScheduler->threadsCount)
			{
				break;
			}
			Thread::Sleep(1);
		}

		// Main loop: runs until the scheduler destructor sets this worker's state to EXIT
		while(context.state.Get() != internal::ThreadState::EXIT)
		{
			internal::GroupedTask task;
			// Prefer the own queue; fall back to stealing from another worker
			if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) )
			{
				// There is a new task: get a fiber for it (a pooled one, or the one it was waiting on)
				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
				MT_ASSERT(fiberContext, "Can't get execution context from pool");
				MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

				// Keep running on this thread as long as there is a fiber to execute or resume
				while(fiberContext)
				{
#ifdef MT_INSTRUMENTED_BUILD
					context.NotifyTaskResumed(fiberContext->currentTask);
#endif

					// Guard: hold an extra child reference while the fiber runs here, so a child task
					// finishing on another thread can't see the counter reach zero and resume this
					// fiber before ExecuteTask is done
					fiberContext->childrenFibersCount.Inc();

					FiberContext* parentFiber = ExecuteTask(context, fiberContext);

					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

					// Release the guard; the result is the number of children still outstanding
					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

					// Task is finished: its fiber context can be dropped
					if (taskStatus == FiberTaskStatus::FINISHED)
					{
						MT_ASSERT( childrenFibersCount == 0, "Sanity check failed");
						context.taskScheduler->ReleaseFiberContext(fiberContext);

						// If a parent fiber was handed back, resume it next; if it is null, leave the loop
						fiberContext = parentFiber;
					} else
					{
						MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed");

						// Not finished and no children left: all subtasks already finished before the parent returned from ExecuteTask
						if (childrenFibersCount == 0)
						{
							MT_ASSERT(parentFiber == nullptr, "Sanity check failed");
						} else
						{
							// Subtasks still exist: drop execution here. The task is resumed when its last subtask finishes
							// (ExecuteTask hands the parent back to whichever worker ran that subtask)
							break;
						}

						// Task is waiting for a group: drop execution. It is resumed when RestoreAwaitingTasks is called
						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
						{
							break;
						}

						// Otherwise loop again and resume the same fiber on this thread
					}
				} //while(fiberContext)

			} else
			{
#ifdef MT_INSTRUMENTED_BUILD
				int64 waitFrom = MT::GetTimeMicroSeconds();
#endif

				// Queue is empty and stealing attempt failed.
				// Sleep until new tasks are signalled; the timeout (presumably milliseconds) also makes
				// the loop re-check the EXIT state periodically
				context.hasNewTasksEvent.Wait(2000);

#ifdef MT_INSTRUMENTED_BUILD
				int64 waitTo = MT::GetTimeMicroSeconds();
				context.NotifyWorkerAwait(waitFrom, waitTo);
#endif

			}

		} // main thread loop
	}
368 
369 	void TaskScheduler::RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
370 	{
371 		// This storage is necessary to calculate how many tasks we add to different groups
372 		int newTaskCountInGroup[MT_MAX_GROUPS_COUNT];
373 
374 		// Default value is 0
375 		memset(&newTaskCountInGroup[0], 0, sizeof(newTaskCountInGroup));
376 
377 		// Set parent fiber pointer
378 		// Calculate the number of tasks per group
379 		// Calculate total number of tasks
380 		size_t count = 0;
381 		for (size_t i = 0; i < buckets.Size(); ++i)
382 		{
383 			internal::TaskBucket& bucket = buckets[i];
384 			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
385 			{
386 				internal::GroupedTask & task = bucket.tasks[taskIndex];
387 
388 				task.parentFiber = parentFiber;
389 				newTaskCountInGroup[task.group]++;
390 			}
391 
392 			count += bucket.count;
393 		}
394 
395 		// Increments child fibers count on parent fiber
396 		if (parentFiber)
397 		{
398 			parentFiber->childrenFibersCount.Add((uint32)count);
399 		}
400 
401 		if (restoredFromAwaitState == false)
402 		{
403 			// Increase the number of active tasks in the group using data from temporary storage
404 			for (size_t i = 0; i < MT_MAX_GROUPS_COUNT; i++)
405 			{
406 				int groupNewTaskCount = newTaskCountInGroup[i];
407 				if (groupNewTaskCount > 0)
408 				{
409 					groupStats[i].Reset();
410 					groupStats[i].Add((uint32)groupNewTaskCount);
411 				}
412 			}
413 
414 			// Increments all task in progress counter
415 			allGroups.Reset();
416 			allGroups.Add((uint32)count);
417 		} else
418 		{
419 			// If task's restored from await state, counters already in correct state
420 		}
421 
422 		// Add to thread queue
423 		for (size_t i = 0; i < buckets.Size(); ++i)
424 		{
425 			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
426 			internal::ThreadContext & context = threadContext[bucketIndex];
427 
428 			internal::TaskBucket& bucket = buckets[i];
429 
430 			context.queue.PushRange(bucket.tasks, bucket.count);
431 			context.hasNewTasksEvent.Signal();
432 		}
433 	}
434 
435 	bool TaskScheduler::WaitGroup(TaskGroup group, uint32 milliseconds)
436 	{
437 		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);
438 
439 		TaskScheduler::TaskGroupDescription  & groupDesc = GetGroupDesc(group);
440 		return groupDesc.Wait(milliseconds);
441 	}
442 
443 	bool TaskScheduler::WaitAll(uint32 milliseconds)
444 	{
445 		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);
446 
447 		return allGroups.Wait(milliseconds);
448 	}
449 
450 	bool TaskScheduler::IsEmpty()
451 	{
452 		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
453 		{
454 			if (!threadContext[i].queue.IsEmpty())
455 			{
456 				return false;
457 			}
458 		}
459 		return true;
460 	}
461 
462 	uint32 TaskScheduler::GetWorkerCount() const
463 	{
464 		return threadsCount;
465 	}
466 
467 	bool TaskScheduler::IsWorkerThread() const
468 	{
469 		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
470 		{
471 			if (threadContext[i].thread.IsCurrentThread())
472 			{
473 				return true;
474 			}
475 		}
476 		return false;
477 	}
478 
479 	TaskGroup TaskScheduler::CreateGroup()
480 	{
481 		MT_ASSERT(IsWorkerThread() == false, "Can't use CreateGroup inside Task.");
482 
483 		TaskGroup group = MT::INVALID_GROUP;
484 		if (!availableGroups.TryPop(group))
485 		{
486 			MT_ASSERT(false, "Group pool is empty");
487 		}
488 
489 		MT_ASSERT(groupStats[group].debugIsFree == true, "Bad logic!");
490 		groupStats[group].debugIsFree = false;
491 
492 		return group;
493 	}
494 
495 	void TaskScheduler::ReleaseGroup(TaskGroup group)
496 	{
497 		MT_ASSERT(IsWorkerThread() == false, "Can't use ReleaseGroup inside Task.");
498 		MT_ASSERT(group != MT::DEFAULT_GROUP && group >= 0 && group < MT_MAX_GROUPS_COUNT, "Invalid group ID");
499 
500 		MT_ASSERT(groupStats[group].debugIsFree == false, "Group already released");
501 		groupStats[group].debugIsFree = true;
502 
503 		availableGroups.Push(group);
504 	}
505 
506 	TaskScheduler::TaskGroupDescription & TaskScheduler::GetGroupDesc(TaskGroup group)
507 	{
508 		MT_ASSERT(group >= 0 && group < MT_MAX_GROUPS_COUNT, "Invalid group ID");
509 
510 		TaskScheduler::TaskGroupDescription & groupDesc = groupStats[group];
511 
512 		MT_ASSERT(groupDesc.debugIsFree == false, "Invalid group");
513 		return groupDesc;
514 	}
515 
516 
517 #ifdef MT_INSTRUMENTED_BUILD
518 
519 	size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize)
520 	{
521 		if (workerIndex >= MT_MAX_THREAD_COUNT)
522 		{
523 			return 0;
524 		}
525 
526 		size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize);
527 		return elementsCount;
528 	}
529 
530 	void TaskScheduler::UpdateProfiler()
531 	{
532 		profilerWebServer.Update(*this);
533 	}
534 
535 	int32 TaskScheduler::GetWebServerPort() const
536 	{
537 		return webServerPort;
538 	}
539 
540 
541 
542 #endif
543 }
544