// The MIT License (MIT)
//
// 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// 	Permission is hereby granted, free of charge, to any person obtaining a copy
// 	of this software and associated documentation files (the "Software"), to deal
// 	in the Software without restriction, including without limitation the rights
// 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// 	copies of the Software, and to permit persons to whom the Software is
// 	furnished to do so, subject to the following conditions:
//
// 	The above copyright notice and this permission notice shall be included in
// 	all copies or substantial portions of the Software.
//
// 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// 	THE SOFTWARE.

#include <MTScheduler.h>

namespace MT
{

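	// Constructor: picks the worker thread count, pre-creates the fiber pool
	// and launches the worker threads.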
	TaskScheduler::TaskScheduler(uint32 workerThreadsCount)
		: roundRobinThreadIndex(0)
		, startedThreadsCount(0)
	{
#ifdef MT_INSTRUMENTED_BUILD
		webServerPort = profilerWebServer.Serve(8080, 8090);

		// initialize start time
		startTime = MT::GetTimeMicroSeconds();
#endif

		if (workerThreadsCount != 0)
		{
			threadsCount = MT::Clamp(workerThreadsCount, (uint32)1, (uint32)MT_MAX_THREAD_COUNT);
		} else
		{
			// Query the number of hardware threads, reserve two of them and clamp to the supported range
			threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT);
		}

		// create fiber pool
		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
		{
			FiberContext& context = fiberContext[i];
			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
			availableFibers.Push( &context );
		}

		// create worker thread pool
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].SetThreadIndex(i);
			threadContext[i].taskScheduler = this;
			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
		}
	}

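	// Destructor: signals every worker thread to exit, then stops each thread.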
	TaskScheduler::~TaskScheduler()
	{
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].state.Set(internal::ThreadState::EXIT);
			threadContext[i].hasNewTasksEvent.Signal();
		}

		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].thread.Stop();
		}
	}

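	// Returns the fiber that was already waiting on this task, or binds the task
	// to a free fiber taken from the pool.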
	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
	{
		FiberContext *fiberContext = task.awaitingFiber;
		if (fiberContext)
		{
			task.awaitingFiber = nullptr;
			return fiberContext;
		}

		if (!availableFibers.TryPop(fiberContext))
		{
			MT_ASSERT(false, "Fiber pool is empty");
		}

		fiberContext->currentTask = task.desc;
		fiberContext->currentGroup = task.group;
		fiberContext->parentFiber = task.parentFiber;
		return fiberContext;
	}

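	// Resets a finished fiber and returns it to the pool of available fibers.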
	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
	{
		MT_ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber");
		fiberContext->Reset();
		availableFibers.Push(fiberContext);
	}

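	// Switches the worker thread to the task fiber and, when control comes back,
	// updates the group and parent counters. Returns the parent fiber to resume
	// when the last child task has finished, or nullptr otherwise.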
	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
	{
		MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");

		MT_ASSERT(fiberContext, "Invalid fiber context");
		MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");

		// Attach the actual thread context to the fiber
		fiberContext->SetThreadContext(&threadContext);

		// Update task status
		fiberContext->SetStatus(FiberTaskStatus::RUNNED);

		MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

		// Run the current task code
		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);

		// If the task has finished
		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
		if (taskStatus == FiberTaskStatus::FINISHED)
		{
			TaskGroup* taskGroup = fiberContext->currentGroup;

			int groupTaskCount = 0;

			// Update group status
			if (taskGroup != nullptr)
			{
				groupTaskCount = taskGroup->inProgressTaskCount.Dec();
				MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
				if (groupTaskCount == 0)
				{
					// Restore awaiting tasks
					threadContext.RestoreAwaitingTasks(taskGroup);
					taskGroup->allDoneEvent.Signal();
				}
			}

			// Update the total task count
			groupTaskCount = threadContext.taskScheduler->allGroups.inProgressTaskCount.Dec();
			MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
			if (groupTaskCount == 0)
			{
				// Notify that all tasks in all groups are finished
				threadContext.taskScheduler->allGroups.allDoneEvent.Signal();
			}

			FiberContext* parentFiberContext = fiberContext->parentFiber;
			if (parentFiberContext != nullptr)
			{
				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
				MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!");

				if (childrenFibersCount == 0)
				{
					// This is the last subtask. Restore the parent task
#if MT_FIBER_DEBUG

					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
					MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");

					// Touch the debug variables to suppress unused-variable warnings
					ownerThread;
					parentTaskStatus;
					parentThreadContext;
					fiberUsageCounter;
#endif

					MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
					MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");

					// WARNING!! The thread context may have changed here! Set the actual current thread context.
					parentFiberContext->SetThreadContext(&threadContext);

					MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

					// All subtasks are done.
					// Exit and return the parent fiber to the scheduler
					return parentFiberContext;
				} else
				{
					// Other subtasks still exist
					// Exit
					return nullptr;
				}
			} else
			{
				// The task is finished and has no parent task
				// Exit
				return nullptr;
			}
		}

		MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status");
		return nullptr;
	}

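	// Fiber entry point: runs the task bound to this fiber context, marks it as
	// finished and switches back to the scheduler fiber of the current worker.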
	void TaskScheduler::FiberMain(void* userData)
	{
		FiberContext& fiberContext = *(FiberContext*)(userData);
		for(;;)
		{
			MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
			MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
			MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );

			fiberContext.SetStatus(FiberTaskStatus::FINISHED);

#ifdef MT_INSTRUMENTED_BUILD
			fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask);
#endif

			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
		}
	}

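	// Work stealing: starting from a random victim index, tries to pop a task
	// from another worker's queue, skipping this worker's own queue.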
	bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount)
	{
		if (workersCount <= 1)
		{
			return false;
		}

		uint32 victimIndex = threadContext.random.Get();

		for (uint32 attempt = 0; attempt < workersCount; attempt++)
		{
			uint32 index = victimIndex % workersCount;
			if (index == threadContext.workerIndex)
			{
				victimIndex++;
				index = victimIndex % workersCount;
			}

			internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index];
			if (victimContext.queue.TryPop(task))
			{
				return true;
			}

			victimIndex++;
		}
		return false;
	}

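	// Worker thread entry point: waits until every worker has started, then pops
	// (or steals) tasks and executes them until the EXIT state is requested.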
	void TaskScheduler::ThreadMain( void* userData )
	{
		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
		MT_ASSERT(context.taskScheduler, "Task scheduler must not be null!");
		context.schedulerFiber.CreateFromThread(context.thread);

		uint32 workersCount = context.taskScheduler->GetWorkerCount();

		context.taskScheduler->startedThreadsCount.Inc();

		// Wait until all worker threads have started and initialized
		while(true)
		{
			if (context.taskScheduler->startedThreadsCount.Get() == (int)context.taskScheduler->threadsCount)
			{
				break;
			}
			Thread::Sleep(1);
		}

		while(context.state.Get() != internal::ThreadState::EXIT)
		{
			internal::GroupedTask task;
			if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) )
			{
				// There is a new task
				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
				MT_ASSERT(fiberContext, "Can't get execution context from pool");
				MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

				while(fiberContext)
				{
#ifdef MT_INSTRUMENTED_BUILD
					context.NotifyTaskResumed(fiberContext->currentTask);
#endif

					// Prevent an invalid fiber resume from child tasks before ExecuteTask is done
					fiberContext->childrenFibersCount.Inc();

					FiberContext* parentFiber = ExecuteTask(context, fiberContext);

					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

					// Release the guard
					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

					// The fiber context can be dropped - the task is finished
					if (taskStatus == FiberTaskStatus::FINISHED)
					{
						MT_ASSERT( childrenFibersCount == 0, "Sanity check failed");
						context.taskScheduler->ReleaseFiberContext(fiberContext);

						// If a parent fiber exists, transfer control to it; if the parent fiber is null, exit
						fiberContext = parentFiber;
					} else
					{
						MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed");

						// No subtasks remain and the status is not finished; all subtasks finished before the parent returned from ExecuteTask
						if (childrenFibersCount == 0)
						{
							MT_ASSERT(parentFiber == nullptr, "Sanity check failed");
						} else
						{
							// Subtasks still exist, so drop the current task execution; it will be resumed when the last subtask finishes
							break;
						}

						// The task is awaiting a group, so drop execution; it will be resumed when RestoreAwaitingTasks is called
						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
						{
							break;
						}
					}
				} //while(fiberContext)

			} else
			{
#ifdef MT_INSTRUMENTED_BUILD
				int64 waitFrom = MT::GetTimeMicroSeconds();
#endif

				// The queue is empty and the stealing attempt failed
				// Wait for new events
				context.hasNewTasksEvent.Wait(2000);

#ifdef MT_INSTRUMENTED_BUILD
				int64 waitTo = MT::GetTimeMicroSeconds();
				context.NotifyWorkerAwait(waitFrom, waitTo);
#endif

			}

		} // main thread loop
	}

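	// Distributes the task buckets across the worker queues (round-robin) and
	// updates the per-group and global in-progress counters before signaling the workers.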
	void TaskScheduler::RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
	{
		// Set the parent fiber pointer,
		// update the per-group task counters
		// and calculate the total number of tasks
		size_t count = 0;
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			internal::TaskBucket& bucket = buckets[i];
			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
			{
				internal::GroupedTask & task = bucket.tasks[taskIndex];

				task.parentFiber = parentFiber;

				if (task.group != nullptr)
				{
					//TODO: reduce the number of reset calls
					task.group->allDoneEvent.Reset();
					task.group->inProgressTaskCount.Inc();
				}
			}

			count += bucket.count;
		}

		// Increment the child fiber count on the parent fiber
		if (parentFiber)
		{
			parentFiber->childrenFibersCount.Add((uint32)count);
		}

		if (restoredFromAwaitState == false)
		{
			// Increment the global in-progress task counter
			allGroups.allDoneEvent.Reset();
			allGroups.inProgressTaskCount.Add((uint32)count);
		} else
		{
			// Tasks restored from the await state already have their counters in the correct state
		}

		// Add the buckets to the thread queues
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
			internal::ThreadContext & context = threadContext[bucketIndex];

			internal::TaskBucket& bucket = buckets[i];

			context.queue.PushRange(bucket.tasks, bucket.count);
			context.hasNewTasksEvent.Signal();
		}
	}

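	// Blocks the calling (non-worker) thread until every task in the group has
	// finished or the timeout expires.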
	bool TaskScheduler::WaitGroup(TaskGroup* group, uint32 milliseconds)
	{
		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);

		return group->allDoneEvent.Wait(milliseconds);
	}

	bool TaskScheduler::WaitAll(uint32 milliseconds)
	{
		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);

		return allGroups.allDoneEvent.Wait(milliseconds);
	}

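	// Returns true when every per-thread task queue is empty.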
	bool TaskScheduler::IsEmpty()
	{
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (!threadContext[i].queue.IsEmpty())
			{
				return false;
			}
		}
		return true;
	}

	uint32 TaskScheduler::GetWorkerCount() const
	{
		return threadsCount;
	}

	bool TaskScheduler::IsWorkerThread() const
	{
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (threadContext[i].thread.IsCurrentThread())
			{
				return true;
			}
		}
		return false;
	}

#ifdef MT_INSTRUMENTED_BUILD

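	// Copies the collected profiler events for the given worker into dstBuffer
	// and returns the number of copied events.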
	size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize)
	{
		if (workerIndex >= MT_MAX_THREAD_COUNT)
		{
			return 0;
		}

		size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize);
		return elementsCount;
	}

	void TaskScheduler::UpdateProfiler()
	{
		profilerWebServer.Update(*this);
	}

	int32 TaskScheduler::GetWebServerPort() const
	{
		return webServerPort;
	}

#endif
}