// The MIT License (MIT)
//
// 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// 	Permission is hereby granted, free of charge, to any person obtaining a copy
// 	of this software and associated documentation files (the "Software"), to deal
// 	in the Software without restriction, including without limitation the rights
// 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// 	copies of the Software, and to permit persons to whom the Software is
// 	furnished to do so, subject to the following conditions:
//
// 	The above copyright notice and this permission notice shall be included in
// 	all copies or substantial portions of the Software.
//
// 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// 	THE SOFTWARE.

#include <MTScheduler.h>

namespace MT
{

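	// A minimal construction/usage sketch (illustrative only; task submission goes through
	// the public API declared in MTScheduler.h, which is not shown in this file):
	//
	//   MT::TaskScheduler scheduler(0); // 0 -> derive the worker count from hardware threads
	//   // ... schedule tasks via the public API from MTScheduler.h ...
	//   scheduler.WaitAll(10000);       // block until all task groups drain (timeout in ms)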
	TaskScheduler::TaskScheduler(uint32 workerThreadsCount)
		: roundRobinThreadIndex(0)
		, startedThreadsCount(0)
	{
#ifdef MT_INSTRUMENTED_BUILD
		// Start the profiler web server (a port is chosen from the 8080..8090 range)
		webServerPort = profilerWebServer.Serve(8080, 8090);

		// Initialize start time
		startTime = MT::GetTimeMicroSeconds();
#endif

		if (workerThreadsCount != 0)
		{
			threadsCount = MT::Clamp(workerThreadsCount, (uint32)1, (uint32)MT_MAX_THREAD_COUNT);
		} else
		{
			// Query the number of hardware threads, leaving headroom for the main thread and the OS
			threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT);
		}
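		// Worked example (illustrative): on a machine reporting 8 hardware threads, passing
		// workerThreadsCount == 0 gives threadsCount = Clamp(8 - 2, 1, MT_MAX_THREAD_COUNT) = 6
		// (assuming MT_MAX_THREAD_COUNT >= 6), leaving two hardware threads free.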

		// create fiber pool
		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
		{
			FiberContext& context = fiberContext[i];
			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
			availableFibers.Push( &context );
		}

		// create worker thread pool
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].SetThreadIndex(i);
			threadContext[i].taskScheduler = this;
			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
		}
	}

	TaskScheduler::~TaskScheduler()
	{
		// Ask every worker to exit and wake it up in case it is waiting for new tasks
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].state.Set(internal::ThreadState::EXIT);
			threadContext[i].hasNewTasksEvent.Signal();
		}

		// Stop the worker threads
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].thread.Stop();
		}
	}

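	// Returns the fiber that should run the given task: either the fiber the task was
	// already waiting on (when resuming), or a fresh fiber taken from the pool.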
	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
	{
		// A task restored from the await state already owns a fiber - reuse it
		FiberContext *fiberContext = task.awaitingFiber;
		if (fiberContext)
		{
			task.awaitingFiber = nullptr;
			return fiberContext;
		}

		if (!availableFibers.TryPop(fiberContext))
		{
			MT_ASSERT(false, "Fiber pool is empty");
		}

		fiberContext->currentTask = task.desc;
		fiberContext->currentGroup = task.group;
		fiberContext->parentFiber = task.parentFiber;
		return fiberContext;
	}

	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
	{
		MT_ASSERT(fiberContext != nullptr, "Can't release a nullptr fiber");
		fiberContext->Reset();
		availableFibers.Push(fiberContext);
	}

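	// Switches to the task's fiber and runs it until it finishes or yields.
	// Returns the parent fiber to resume when this was the last outstanding child task,
	// or nullptr when there is nothing further to run on this path.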
	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
	{
		MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");

		MT_ASSERT(fiberContext, "Invalid fiber context");
		MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");

		// Attach the current thread context to the fiber
		fiberContext->SetThreadContext(&threadContext);

		// Update task status
		fiberContext->SetStatus(FiberTaskStatus::RUNNED);

		MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

		// Run the current task code
		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);

		// Check whether the task has finished
		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
		if (taskStatus == FiberTaskStatus::FINISHED)
		{
			TaskGroup* taskGroup = fiberContext->currentGroup;

			int groupTaskCount = 0;

			// Update group status
			if (taskGroup != nullptr)
			{
				groupTaskCount = taskGroup->Dec();
				MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
				if (groupTaskCount == 0)
				{
					// Restore awaiting tasks
					threadContext.RestoreAwaitingTasks(taskGroup);
					taskGroup->Signal();
				}
			}

			// Update total task count
			groupTaskCount = threadContext.taskScheduler->allGroups.Dec();
			MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
			if (groupTaskCount == 0)
			{
				// Notify that all tasks in all groups have finished
				threadContext.taskScheduler->allGroups.Signal();
			}

			FiberContext* parentFiberContext = fiberContext->parentFiber;
			if (parentFiberContext != nullptr)
			{
				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
				MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!");

				if (childrenFibersCount == 0)
				{
					// This is the last subtask. Restore the parent task
#if MT_FIBER_DEBUG

					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
					MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");

					// Referenced to silence unused-variable warnings in builds where MT_ASSERT is a no-op
					ownerThread;
					parentTaskStatus;
					parentThreadContext;
					fiberUsageCounter;
#endif

					MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
					MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");

					// WARNING!! The thread context can change here! Set the actual current thread context.
					parentFiberContext->SetThreadContext(&threadContext);

					MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

					// All subtasks are done.
					// Exit and return the parent fiber to the scheduler
					return parentFiberContext;
				} else
				{
					// Other subtasks still exist
					// Exit
					return nullptr;
				}
			} else
			{
				// The task is finished and has no parent task
				// Exit
				return nullptr;
			}
		}

		MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status");
		return nullptr;
	}


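	// Fiber entry point. Each pooled fiber loops forever: run the current task, mark it finished,
	// then switch back to the scheduler fiber and wait to be re-entered with a new task.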
	void TaskScheduler::FiberMain(void* userData)
	{
		FiberContext& fiberContext = *(FiberContext*)(userData);
		for(;;)
		{
			MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
			MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
			MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );

			fiberContext.SetStatus(FiberTaskStatus::FINISHED);

#ifdef MT_INSTRUMENTED_BUILD
			fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask);
#endif

			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
		}
	}

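	// Work stealing: starting from a random victim, probe the other workers' queues
	// (skipping our own) and grab the first task that can be popped.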
	bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount)
	{
		if (workersCount <= 1)
		{
			return false;
		}

		uint32 victimIndex = threadContext.random.Get();

		for (uint32 attempt = 0; attempt < workersCount; attempt++)
		{
			uint32 index = victimIndex % workersCount;
			if (index == threadContext.workerIndex)
			{
				victimIndex++;
				index = victimIndex % workersCount;
			}

			internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index];
			if (victimContext.queue.TryPop(task))
			{
				return true;
			}

			victimIndex++;
		}
		return false;
	}

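	// Worker thread entry point: converts the thread into a scheduler fiber, waits for all
	// workers to start, then loops popping tasks from the local queue (or stealing them)
	// and executing them on pooled fibers until the EXIT state is requested.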
	void TaskScheduler::ThreadMain( void* userData )
	{
		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
		MT_ASSERT(context.taskScheduler, "Task scheduler must not be null!");
		context.schedulerFiber.CreateFromThread(context.thread);

		uint32 workersCount = context.taskScheduler->GetWorkerCount();

		context.taskScheduler->startedThreadsCount.Inc();

		// Spin until all threads have started and initialized
		while(true)
		{
			if (context.taskScheduler->startedThreadsCount.Get() == (int)context.taskScheduler->threadsCount)
			{
				break;
			}
			Thread::Sleep(1);
		}

		while(context.state.Get() != internal::ThreadState::EXIT)
		{
			internal::GroupedTask task;
			if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) )
			{
				// There is a new task
				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
				MT_ASSERT(fiberContext, "Can't get execution context from pool");
				MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

				while(fiberContext)
				{
#ifdef MT_INSTRUMENTED_BUILD
					context.NotifyTaskResumed(fiberContext->currentTask);
#endif

					// Prevent an invalid fiber resume from child tasks before ExecuteTask is done
					fiberContext->childrenFibersCount.Inc();

					FiberContext* parentFiber = ExecuteTask(context, fiberContext);

					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

					// Release the guard
					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

					// If the task is finished, the fiber context can be dropped
					if (taskStatus == FiberTaskStatus::FINISHED)
					{
						MT_ASSERT( childrenFibersCount == 0, "Sanity check failed");
						context.taskScheduler->ReleaseFiberContext(fiberContext);

						// If a parent fiber exists, transfer control to it; if the parent fiber is null, exit
						fiberContext = parentFiber;
					} else
					{
						MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed");

						// If no subtasks are left while the status is not finished, all subtasks
						// already finished before the parent returned from ExecuteTask
						if (childrenFibersCount == 0)
						{
							MT_ASSERT(parentFiber == nullptr, "Sanity check failed");
						} else
						{
							// Subtasks still exist; drop the current task execution.
							// The task will be resumed when the last subtask finishes
							break;
						}

						// If the task is in the await state, drop execution.
						// The task will be resumed when RestoreAwaitingTasks is called
						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
						{
							break;
						}
					}
				} //while(fiberContext)

			} else
			{
#ifdef MT_INSTRUMENTED_BUILD
				int64 waitFrom = MT::GetTimeMicroSeconds();
#endif

				// The queue is empty and the stealing attempt failed.
				// Wait for new tasks (the timeout lets the loop re-check the exit state periodically)
				context.hasNewTasksEvent.Wait(2000);

#ifdef MT_INSTRUMENTED_BUILD
				int64 waitTo = MT::GetTimeMicroSeconds();
				context.NotifyWorkerAwait(waitFrom, waitTo);
#endif

			}

		} // main thread loop
	}

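	// Distributes the prepared task buckets across the worker queues: records the parent fiber
	// and updates the group/total counters for every task, then pushes each bucket to a worker
	// chosen round-robin and wakes that worker through its hasNewTasksEvent.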
	void TaskScheduler::RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
	{
		// Set the parent fiber pointer,
		// update the number of tasks per group,
		// and calculate the total number of tasks
		size_t count = 0;
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			internal::TaskBucket& bucket = buckets[i];
			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
			{
				internal::GroupedTask & task = bucket.tasks[taskIndex];

				task.parentFiber = parentFiber;

				if (task.group != nullptr)
				{
					//TODO: reduce the number of reset calls
					task.group->Reset();
					task.group->Inc();
				}
			}

			count += bucket.count;
		}

		// Increment the children fibers count on the parent fiber
		if (parentFiber)
		{
			parentFiber->childrenFibersCount.Add((uint32)count);
		}

		if (restoredFromAwaitState == false)
		{
			// Increment the all-tasks-in-progress counter
			allGroups.Reset();
			allGroups.Add((uint32)count);
		} else
		{
			// Tasks restored from the await state are already accounted for in the counters
		}

		// Push the buckets to the worker queues (round-robin) and wake the workers
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
			internal::ThreadContext & context = threadContext[bucketIndex];

			internal::TaskBucket& bucket = buckets[i];

			context.queue.PushRange(bucket.tasks, bucket.count);
			context.hasNewTasksEvent.Signal();
		}
	}

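	// Blocks the calling (non-worker) thread until every task in the given group has finished
	// or the timeout expires. Illustrative call (assuming a TaskGroup prepared via MTScheduler.h):
	//
	//   bool finished = scheduler.WaitGroup(&group, 5000); // false means the wait timed out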
	bool TaskScheduler::WaitGroup(TaskGroup* group, uint32 milliseconds)
	{
		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);

		return group->Wait(milliseconds);
	}

	bool TaskScheduler::WaitAll(uint32 milliseconds)
	{
		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);

		return allGroups.Wait(milliseconds);
	}

	bool TaskScheduler::IsEmpty()
	{
		// Scan the whole static array; slots beyond threadsCount simply hold empty queues
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (!threadContext[i].queue.IsEmpty())
			{
				return false;
			}
		}
		return true;
	}

	uint32 TaskScheduler::GetWorkerCount() const
	{
		return threadsCount;
	}

	bool TaskScheduler::IsWorkerThread() const
	{
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (threadContext[i].thread.IsCurrentThread())
			{
				return true;
			}
		}
		return false;
	}

#ifdef MT_INSTRUMENTED_BUILD

	size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize)
	{
		if (workerIndex >= MT_MAX_THREAD_COUNT)
		{
			return 0;
		}

		size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize);
		return elementsCount;
	}

	void TaskScheduler::UpdateProfiler()
	{
		profilerWebServer.Update(*this);
	}

	int32 TaskScheduler::GetWebServerPort() const
	{
		return webServerPort;
	}

#endif
}