// The MIT License (MIT)
//
// Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include <MTScheduler.h>

namespace MT
{

	TaskScheduler::TaskScheduler(uint32 workerThreadsCount)
		: roundRobinThreadIndex(0)
	{
#ifdef MT_INSTRUMENTED_BUILD
		webServerPort = profilerWebServer.Serve(8080, 8090);

		// initialize the static start time
		TaskScheduler::GetStartTime();
#endif

		if (workerThreadsCount != 0)
		{
			// use the explicitly requested worker count
			threadsCount = workerThreadsCount;
		} else
		{
			// query the number of hardware threads and reserve two of them;
			// signed math avoids unsigned wrap-around on machines with fewer than two hardware threads
			threadsCount = (uint32)Max((int32)Thread::GetNumberOfHardwareThreads() - 2, (int32)1);
			if (threadsCount > MT_MAX_THREAD_COUNT)
			{
				threadsCount = MT_MAX_THREAD_COUNT;
			}
		}

		// create fiber pool
		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
		{
			FiberContext& context = fiberContext[i];
			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
			availableFibers.Push( &context );
		}

		// create worker thread pool
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].SetThreadIndex(i);
			threadContext[i].taskScheduler = this;
			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
		}
	}

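	// Signal every worker thread to exit, wake it up and then stop it.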
	TaskScheduler::~TaskScheduler()
	{
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].state.Set(internal::ThreadState::EXIT);
			threadContext[i].hasNewTasksEvent.Signal();
		}

		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].thread.Stop();
		}
	}

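	// Returns the fiber previously awaiting this task, if any; otherwise takes a free fiber
	// from the pool and binds the task, its group and its parent fiber to that context.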
	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
	{
		FiberContext *fiberContext = task.awaitingFiber;
		if (fiberContext)
		{
			task.awaitingFiber = nullptr;
			return fiberContext;
		}

		if (!availableFibers.TryPop(fiberContext))
		{
			MT_ASSERT(false, "Fiber pool is empty");
		}

		fiberContext->currentTask = task.desc;
		fiberContext->currentGroup = task.group;
		fiberContext->parentFiber = task.parentFiber;
		return fiberContext;
	}

	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
	{
		MT_ASSERT(fiberContext != nullptr, "Can't release a nullptr fiber");
		fiberContext->Reset();
		availableFibers.Push(fiberContext);
	}

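	// Switches from the scheduler fiber to the task fiber and runs it until it yields or finishes.
	// On completion it updates the per-group and global in-progress counters and, if this was the
	// last child of a suspended parent task, returns the parent fiber so the caller can resume it.
	// Returns nullptr when there is nothing to resume.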
	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
	{
		MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");

		MT_ASSERT(fiberContext, "Invalid fiber context");
		MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");
		MT_ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group");

		// Attach the current thread context to the fiber
		fiberContext->SetThreadContext(&threadContext);

		// Update the task status
		fiberContext->SetStatus(FiberTaskStatus::RUNNED);

		MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

		// Run the current task code
		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);

		// If the task has finished
		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
		if (taskStatus == FiberTaskStatus::FINISHED)
		{
			TaskGroup::Type taskGroup = fiberContext->currentGroup;
			MT_ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group.");

			// Update the group status
			int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec();
			MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
			if (groupTaskCount == 0)
			{
				// Restore tasks awaiting this group
				threadContext.RestoreAwaitingTasks(taskGroup);
				threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal();
			}

			// Update the total task count
			groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec();
			MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
			if (groupTaskCount == 0)
			{
				// Notify that all tasks in all groups have finished
				threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal();
			}

			FiberContext* parentFiberContext = fiberContext->parentFiber;
			if (parentFiberContext != nullptr)
			{
				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
				MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!");

				if (childrenFibersCount == 0)
				{
					// This is the last subtask. Restore the parent task
#if MT_FIBER_DEBUG

					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
					MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");

					// reference the locals to avoid unused variable warnings when asserts are compiled out
					ownerThread;
					parentTaskStatus;
					parentThreadContext;
					fiberUsageCounter;
#endif

					MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
					MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");

					// WARNING!! The executing thread may have changed here! Attach the actual current thread context.
					parentFiberContext->SetThreadContext(&threadContext);

					MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

					// All subtasks are done.
					// Exit and return the parent fiber to the scheduler
					return parentFiberContext;
				} else
				{
					// Other subtasks still exist
					// Exiting
					return nullptr;
				}
			} else
			{
				// The task is finished and has no parent task
				// Exiting
				return nullptr;
			}
		}

		MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status");
		return nullptr;
	}

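	// Fiber entry point. Runs the task bound to this fiber context, marks it as finished and
	// switches back to the scheduler fiber. The infinite loop lets a pooled fiber be reused
	// for the next task without being recreated.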
	void TaskScheduler::FiberMain(void* userData)
	{
		FiberContext& fiberContext = *(FiberContext*)(userData);
		for(;;)
		{
			MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
			MT_ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group");
			MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
			MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );

			fiberContext.SetStatus(FiberTaskStatus::FINISHED);

#ifdef MT_INSTRUMENTED_BUILD
			fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask);
#endif

			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
		}
	}

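	// Work stealing: starting from a random victim index, walk the other workers' queues and
	// try to pop a task from one of them. Returns true when a task was successfully stolen.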
	bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount)
	{
		if (workersCount <= 1)
		{
			return false;
		}

		uint32 victimIndex = threadContext.random.Get();

		for (uint32 attempt = 0; attempt < workersCount; attempt++)
		{
			uint32 index = victimIndex % workersCount;

			// never steal from our own queue
			if (index == threadContext.workerIndex)
			{
				victimIndex++;
				index = victimIndex % workersCount;
			}

			internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index];
			if (victimContext.queue.TryPop(task))
			{
				return true;
			}

			victimIndex++;
		}
		return false;
	}

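	// Worker thread entry point: creates the scheduler fiber from the thread, then loops popping
	// tasks from its own queue (or stealing them) and executing them until the EXIT state is set.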
	void TaskScheduler::ThreadMain( void* userData )
	{
		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
		MT_ASSERT(context.taskScheduler, "Task scheduler must not be null!");
		context.schedulerFiber.CreateFromThread(context.thread);

		uint32 workersCount = context.taskScheduler->GetWorkerCount();

		while(context.state.Get() != internal::ThreadState::EXIT)
		{
			internal::GroupedTask task;
			if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) )
			{
				// There is a new task
				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
				MT_ASSERT(fiberContext, "Can't get execution context from pool");
				MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

				while(fiberContext)
				{
#ifdef MT_INSTRUMENTED_BUILD
					context.NotifyTaskResumed(fiberContext->currentTask);
#endif

					// Prevent an invalid fiber resume from child tasks before ExecuteTask is done
					fiberContext->childrenFibersCount.Inc();

					FiberContext* parentFiber = ExecuteTask(context, fiberContext);

					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

					// Release the guard
					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

					// The task is finished, so the fiber context can be dropped
					if (taskStatus == FiberTaskStatus::FINISHED)
					{
						MT_ASSERT( childrenFibersCount == 0, "Sanity check failed");
						context.taskScheduler->ReleaseFiberContext(fiberContext);

						// If a parent fiber exists, transfer flow control to it; if the parent fiber is null, exit
						fiberContext = parentFiber;
					} else
					{
						MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed");

						// No pending subtasks and the status is not finished: all subtasks finished
						// before the parent returned from ExecuteTask, so resume the parent on the next iteration
						if (childrenFibersCount == 0)
						{
							MT_ASSERT(parentFiber == nullptr, "Sanity check failed");
						} else
						{
							// Subtasks still exist, so drop the current task execution.
							// The task will be resumed when the last subtask finishes
							break;
						}

						// If the task is awaiting a group, drop execution.
						// The task will be resumed when RestoreAwaitingTasks is called
						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
						{
							break;
						}
					}
				} //while(fiberContext)

			} else
			{
#ifdef MT_INSTRUMENTED_BUILD
				int64 waitFrom = MT::GetTimeMicroSeconds();
#endif

				// The queue is empty and the stealing attempt failed.
				// Wait for new tasks
				context.hasNewTasksEvent.Wait(2000);

#ifdef MT_INSTRUMENTED_BUILD
				int64 waitTo = MT::GetTimeMicroSeconds();
				context.NotifyWorkerAwait(waitFrom, waitTo);
#endif

			}

		} // main thread loop
	}

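	// Distributes the task buckets across the worker queues (round-robin), wires up the parent
	// fiber pointer and updates the per-group and global in-progress counters, then wakes the workers.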
	void TaskScheduler::RunTasksImpl(WrapperArray<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
	{
		// Reset the per-group counters to their initial value
		int taskCountInGroup[TaskGroup::COUNT];
		for (size_t i = 0; i < TaskGroup::COUNT; ++i)
		{
			taskCountInGroup[i] = 0;
		}

		// Set the parent fiber pointer,
		// count the number of tasks per group
		// and the total number of tasks
		size_t count = 0;
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			internal::TaskBucket& bucket = buckets[i];
			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
			{
				internal::GroupedTask & task = bucket.tasks[taskIndex];

				MT_ASSERT(task.group < TaskGroup::COUNT, "Invalid group.");

				task.parentFiber = parentFiber;
				taskCountInGroup[task.group]++;
			}
			count += bucket.count;
		}

		// Increment the child fiber count on the parent fiber
		if (parentFiber)
		{
			parentFiber->childrenFibersCount.Add((uint32)count);
		}

		if (restoredFromAwaitState == false)
		{
			// Increment the global in-progress task counter
			allGroupStats.allDoneEvent.Reset();
			allGroupStats.inProgressTaskCount.Add((uint32)count);

			// Increment the in-progress task counters (per group)
			for (size_t i = 0; i < TaskGroup::COUNT; ++i)
			{
				int groupTaskCount = taskCountInGroup[i];
				if (groupTaskCount > 0)
				{
					groupStats[i].allDoneEvent.Reset();
					groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount);
				}
			}
		} else
		{
			// Tasks restored from the await state: the counters are already in the correct state
		}

		// Add the buckets to the worker thread queues
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
			internal::ThreadContext & context = threadContext[bucketIndex];

			internal::TaskBucket& bucket = buckets[i];

			context.queue.PushRange(bucket.tasks, bucket.count);
			context.hasNewTasksEvent.Signal();
		}
	}

	bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds)
	{
		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);

		return groupStats[group].allDoneEvent.Wait(milliseconds);
	}

	bool TaskScheduler::WaitAll(uint32 milliseconds)
	{
		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);

		return allGroupStats.allDoneEvent.Wait(milliseconds);
	}

	bool TaskScheduler::IsEmpty()
	{
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (!threadContext[i].queue.IsEmpty())
			{
				return false;
			}
		}
		return true;
	}

	uint32 TaskScheduler::GetWorkerCount() const
	{
		return threadsCount;
	}

	bool TaskScheduler::IsWorkerThread() const
	{
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (threadContext[i].thread.IsCurrentThread())
			{
				return true;
			}
		}
		return false;
	}

#ifdef MT_INSTRUMENTED_BUILD

	size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize)
	{
		if (workerIndex >= MT_MAX_THREAD_COUNT)
		{
			return 0;
		}

		size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize);
		return elementsCount;
	}

	void TaskScheduler::UpdateProfiler()
	{
		profilerWebServer.Update(*this);
	}

	int32 TaskScheduler::GetWebServerPort() const
	{
		return webServerPort;
	}

	int64 TaskScheduler::GetStartTime()
	{
		static int64 startTime = GetTimeMicroSeconds();
		return startTime;
	}

#endif
}