1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #include <MTScheduler.h>
24 
25 namespace MT
26 {
27 
28 	TaskScheduler::TaskScheduler(uint32 workerThreadsCount)
29 		: roundRobinThreadIndex(0)
30 	{
31 #ifdef MT_INSTRUMENTED_BUILD
32 		webServerPort = profilerWebServer.Serve(8080, 8090);
33 
34 		//initialize start time
35 		startTime = MT::GetTimeMicroSeconds();
36 #endif
37 
38 
39 		if (workerThreadsCount != 0)
40 		{
41 			threadsCount = workerThreadsCount;
42 		} else
43 		{
44 			//query number of processor
45 			threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT);
46 		}
47 
48 
49 		// create fiber pool
50 		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
51 		{
52 			FiberContext& context = fiberContext[i];
53 			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
54 			availableFibers.Push( &context );
55 		}
56 
57 		// create worker thread pool
58 		for (uint32 i = 0; i < threadsCount; i++)
59 		{
60 			threadContext[i].SetThreadIndex(i);
61 			threadContext[i].taskScheduler = this;
62 			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
63 		}
64 	}
65 
66 	TaskScheduler::~TaskScheduler()
67 	{
68 		for (uint32 i = 0; i < threadsCount; i++)
69 		{
70 			threadContext[i].state.Set(internal::ThreadState::EXIT);
71 			threadContext[i].hasNewTasksEvent.Signal();
72 		}
73 
74 		for (uint32 i = 0; i < threadsCount; i++)
75 		{
76 			threadContext[i].thread.Stop();
77 		}
78 	}
79 
80 	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
81 	{
82 		FiberContext *fiberContext = task.awaitingFiber;
83 		if (fiberContext)
84 		{
85 			task.awaitingFiber = nullptr;
86 			return fiberContext;
87 		}
88 
89 		if (!availableFibers.TryPop(fiberContext))
90 		{
91 			MT_ASSERT(false, "Fibers pool is empty");
92 		}
93 
94 		fiberContext->currentTask = task.desc;
95 		fiberContext->currentGroup = task.group;
96 		fiberContext->parentFiber = task.parentFiber;
97 		return fiberContext;
98 	}
99 
100 	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
101 	{
102 		MT_ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber");
103 		fiberContext->Reset();
104 		availableFibers.Push(fiberContext);
105 	}
106 
107 	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
108 	{
109 		MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
110 
111 		MT_ASSERT(fiberContext, "Invalid fiber context");
112 		MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");
113 		MT_ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group");
114 
115 		// Set actual thread context to fiber
116 		fiberContext->SetThreadContext(&threadContext);
117 
118 		// Update task status
119 		fiberContext->SetStatus(FiberTaskStatus::RUNNED);
120 
121 		MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
122 
123 		// Run current task code
124 		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);
125 
126 		// If task was done
127 		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
128 		if (taskStatus == FiberTaskStatus::FINISHED)
129 		{
130 			TaskGroup::Type taskGroup = fiberContext->currentGroup;
131 			MT_ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group.");
132 
133 			// Update group status
134 			int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec();
135 			MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
136 			if (groupTaskCount == 0)
137 			{
138 				// Restore awaiting tasks
139 				threadContext.RestoreAwaitingTasks(taskGroup);
140 				threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal();
141 			}
142 
143 			// Update total task count
144 			groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec();
145 			MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
146 			if (groupTaskCount == 0)
147 			{
148 				// Notify all tasks in all group finished
149 				threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal();
150 			}
151 
152 			FiberContext* parentFiberContext = fiberContext->parentFiber;
153 			if (parentFiberContext != nullptr)
154 			{
155 				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
156 				MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!");
157 
158 				if (childrenFibersCount == 0)
159 				{
160 					// This is a last subtask. Restore parent task
161 #if MT_FIBER_DEBUG
162 
163 					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
164 					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
165 					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
166 					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
167 					MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");
168 
169 					ownerThread;
170 					parentTaskStatus;
171 					parentThreadContext;
172 					fiberUsageCounter;
173 #endif
174 
175 					MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
176 					MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");
177 
178 					// WARNING!! Thread context can changed here! Set actual current thread context.
179 					parentFiberContext->SetThreadContext(&threadContext);
180 
181 					MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
182 
183 					// All subtasks is done.
184 					// Exiting and return parent fiber to scheduler
185 					return parentFiberContext;
186 				} else
187 				{
188 					// Other subtasks still exist
189 					// Exiting
190 					return nullptr;
191 				}
192 			} else
193 			{
194 				// Task is finished and no parent task
195 				// Exiting
196 				return nullptr;
197 			}
198 		}
199 
200 		MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status")
201 		return nullptr;
202 	}
203 
204 
	// Fiber entry point shared by every pooled fiber. Runs forever: executes
	// the task currently bound to this fiber's context, marks it FINISHED and
	// switches back to the owning worker's scheduler fiber, then waits for
	// the context to be rebound and resumed with the next task.
	void TaskScheduler::FiberMain(void* userData)
	{
		FiberContext& fiberContext = *(FiberContext*)(userData);
		for(;;)
		{
			// The scheduler must bind a valid task and thread context before
			// switching into this fiber
			MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
			MT_ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group");
			MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
			MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

			// Run the user-supplied task function
			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );

			fiberContext.SetStatus(FiberTaskStatus::FINISHED);

#ifdef MT_INSTRUMENTED_BUILD
			fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask);
#endif

			// Hand control back to the scheduler fiber of the current worker
			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
		}

	}
227 
228 
229 	bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount)
230 	{
231 		if (workersCount <= 1)
232 		{
233 			return false;
234 		}
235 
236 		uint32 victimIndex = threadContext.random.Get();
237 
238 		for (uint32 attempt = 0; attempt < workersCount; attempt++)
239 		{
240 			uint32 index = victimIndex % workersCount;
241 			if (index == threadContext.workerIndex)
242 			{
243 				victimIndex++;
244 				index = victimIndex % workersCount;
245 			}
246 
247 			internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index];
248 			if (victimContext.queue.TryPop(task))
249 			{
250 				return true;
251 			}
252 
253 			victimIndex++;
254 		}
255 		return false;
256 	}
257 
	// Worker thread entry point. Converts the thread into the scheduler
	// ("home") fiber, then loops until EXIT is requested: pop a task from the
	// local queue (or steal one), execute it on a fiber, and walk back up the
	// parent chain when finished subtasks release their parents.
	void TaskScheduler::ThreadMain( void* userData )
	{
		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
		MT_ASSERT(context.taskScheduler, "Task scheduler must be not null!");
		// Scheduler fiber is what task fibers switch back to in ExecuteTask
		context.schedulerFiber.CreateFromThread(context.thread);

		uint32 workersCount = context.taskScheduler->GetWorkerCount();

		while(context.state.Get() != internal::ThreadState::EXIT)
		{
			internal::GroupedTask task;
			// Prefer the local queue; fall back to stealing from other workers
			if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) )
			{
				// There is a new task
				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
				MT_ASSERT(fiberContext, "Can't get execution context from pool");
				MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

				// Execute the task; when a finished child hands back its
				// suspended parent fiber, the parent is resumed next iteration
				while(fiberContext)
				{
#ifdef MT_INSTRUMENTED_BUILD
					context.NotifyTaskResumed(fiberContext->currentTask);
#endif

					// prevent invalid fiber resume from child tasks, before ExecuteTask is done
					fiberContext->childrenFibersCount.Inc();

					FiberContext* parentFiber = ExecuteTask(context, fiberContext);

					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

					//release guard (undo the Inc above)
					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

					// Can drop fiber context - task is finished
					if (taskStatus == FiberTaskStatus::FINISHED)
					{
						MT_ASSERT( childrenFibersCount == 0, "Sanity check failed");
						context.taskScheduler->ReleaseFiberContext(fiberContext);

						// If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit
						fiberContext = parentFiber;
					} else
					{
						MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed");

						// No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask
						if (childrenFibersCount == 0)
						{
							MT_ASSERT(parentFiber == nullptr, "Sanity check failed");
						} else
						{
							// If subtasks still exist, drop current task execution. task will be resumed when last subtask finished
							break;
						}

						// If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called
						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
						{
							break;
						}
					}
				} //while(fiberContext)

			} else
			{
#ifdef MT_INSTRUMENTED_BUILD
				int64 waitFrom = MT::GetTimeMicroSeconds();
#endif

				// Queue is empty and stealing attempt failed
				// Wait new events; the timeout lets the EXIT flag be re-checked
				// periodically even if the wake-up signal is missed
				context.hasNewTasksEvent.Wait(2000);

#ifdef MT_INSTRUMENTED_BUILD
				int64 waitTo = MT::GetTimeMicroSeconds();
				context.NotifyWorkerAwait(waitFrom, waitTo);
#endif

			}

		} // main thread loop
	}
341 
342 	void TaskScheduler::RunTasksImpl(WrapperArray<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
343 	{
344 		// Reset counter to initial value
345 		int taskCountInGroup[TaskGroup::COUNT];
346 		for (size_t i = 0; i < TaskGroup::COUNT; ++i)
347 		{
348 			taskCountInGroup[i] = 0;
349 		}
350 
351 		// Set parent fiber pointer
352 		// Calculate the number of tasks per group
353 		// Calculate total number of tasks
354 		size_t count = 0;
355 		for (size_t i = 0; i < buckets.Size(); ++i)
356 		{
357 			internal::TaskBucket& bucket = buckets[i];
358 			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
359 			{
360 				internal::GroupedTask & task = bucket.tasks[taskIndex];
361 
362 				MT_ASSERT(task.group < TaskGroup::COUNT, "Invalid group.");
363 
364 				task.parentFiber = parentFiber;
365 				taskCountInGroup[task.group]++;
366 			}
367 			count += bucket.count;
368 		}
369 
370 		// Increments child fibers count on parent fiber
371 		if (parentFiber)
372 		{
373 			parentFiber->childrenFibersCount.Add((uint32)count);
374 		}
375 
376 		if (restoredFromAwaitState == false)
377 		{
378 			// Increments all task in progress counter
379 			allGroupStats.allDoneEvent.Reset();
380 			allGroupStats.inProgressTaskCount.Add((uint32)count);
381 
382 			// Increments task in progress counters (per group)
383 			for (size_t i = 0; i < TaskGroup::COUNT; ++i)
384 			{
385 				int groupTaskCount = taskCountInGroup[i];
386 				if (groupTaskCount > 0)
387 				{
388 					groupStats[i].allDoneEvent.Reset();
389 					groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount);
390 				}
391 			}
392 		} else
393 		{
394 			// If task's restored from await state, counters already in correct state
395 		}
396 
397 		// Add to thread queue
398 		for (size_t i = 0; i < buckets.Size(); ++i)
399 		{
400 			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
401 			internal::ThreadContext & context = threadContext[bucketIndex];
402 
403 			internal::TaskBucket& bucket = buckets[i];
404 
405 			context.queue.PushRange(bucket.tasks, bucket.count);
406 			context.hasNewTasksEvent.Signal();
407 		}
408 	}
409 
410 	bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds)
411 	{
412 		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);
413 
414 		return groupStats[group].allDoneEvent.Wait(milliseconds);
415 	}
416 
417 	bool TaskScheduler::WaitAll(uint32 milliseconds)
418 	{
419 		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);
420 
421 		return allGroupStats.allDoneEvent.Wait(milliseconds);
422 	}
423 
424 	bool TaskScheduler::IsEmpty()
425 	{
426 		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
427 		{
428 			if (!threadContext[i].queue.IsEmpty())
429 			{
430 				return false;
431 			}
432 		}
433 		return true;
434 	}
435 
	// Returns the number of worker threads owned by this scheduler
	// (fixed at construction time).
	uint32 TaskScheduler::GetWorkerCount() const
	{
		return threadsCount;
	}
440 
441 	bool TaskScheduler::IsWorkerThread() const
442 	{
443 		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
444 		{
445 			if (threadContext[i].thread.IsCurrentThread())
446 			{
447 				return true;
448 			}
449 		}
450 		return false;
451 	}
452 
453 #ifdef MT_INSTRUMENTED_BUILD
454 
455 	size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize)
456 	{
457 		if (workerIndex >= MT_MAX_THREAD_COUNT)
458 		{
459 			return 0;
460 		}
461 
462 		size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize);
463 		return elementsCount;
464 	}
465 
	// Pumps the embedded profiler web server (instrumented builds only).
	void TaskScheduler::UpdateProfiler()
	{
		profilerWebServer.Update(*this);
	}
470 
	// Returns the port the profiler web server bound to in the constructor
	// (chosen from the 8080..8090 range).
	int32 TaskScheduler::GetWebServerPort() const
	{
		return webServerPort;
	}
475 
476 
477 
478 #endif
479 }
480