// The MIT License (MIT)
//
// Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include <MTScheduler.h>

namespace MT
{

	TaskScheduler::TaskScheduler(uint32 workerThreadsCount)
		: roundRobinThreadIndex(0)
	{
#ifdef MT_INSTRUMENTED_BUILD
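		// Instrumented builds start the embedded profiler web server. Serve(8080, 8090) presumably
		// picks a free port in that range; GetWebServerPort() reports the port that was chosen.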
		webServerPort = profilerWebServer.Serve(8080, 8090);

		// Initialize the start time
		startTime = MT::GetTimeMicroSeconds();
#endif


		if (workerThreadsCount != 0)
		{
			threadsCount = workerThreadsCount;
		} else
		{
			// Query the number of hardware threads, leave two of them free and clamp the result to [1, MT_MAX_THREAD_COUNT]
			threadsCount = (uint32)MT::Clamp(Thread::GetNumberOfHardwareThreads() - 2, 1, (int)MT_MAX_THREAD_COUNT);
		}

		// create fiber pool
		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
		{
			FiberContext& context = fiberContext[i];
			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
			availableFibers.Push( &context );
		}

		// create worker thread pool
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].SetThreadIndex(i);
			threadContext[i].taskScheduler = this;
			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
		}
	}

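	// Ask every worker to exit, then join the worker threads.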
	TaskScheduler::~TaskScheduler()
	{
		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].state.Set(internal::ThreadState::EXIT);
			threadContext[i].hasNewTasksEvent.Signal();
		}

		for (uint32 i = 0; i < threadsCount; i++)
		{
			threadContext[i].thread.Stop();
		}
	}

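	// Picks the fiber that will run the task. A task that is being resumed already carries its awaiting
	// fiber; otherwise a fresh fiber is taken from the pool and bound to the task description, its group
	// and its parent fiber.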
	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
	{
		FiberContext *fiberContext = task.awaitingFiber;
		if (fiberContext)
		{
			task.awaitingFiber = nullptr;
			return fiberContext;
		}

		if (!availableFibers.TryPop(fiberContext))
		{
			MT_ASSERT(false, "Fibers pool is empty");
		}

		fiberContext->currentTask = task.desc;
		fiberContext->currentGroup = task.group;
		fiberContext->parentFiber = task.parentFiber;
		return fiberContext;
	}

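	// Resets a finished fiber and returns it to the pool for reuse.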
	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
	{
		MT_ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber");
		fiberContext->Reset();
		availableFibers.Push(fiberContext);
	}

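	// Runs the task bound to fiberContext by switching from the worker's scheduler fiber to the task fiber.
	// When control comes back, the task status decides what happens next: on FINISHED the group and global
	// in-progress counters are decremented, and if this task was the last child of a parent fiber, that
	// parent fiber is returned so the caller can resume it. In every other case nullptr is returned.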
	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
	{
		MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");

		MT_ASSERT(fiberContext, "Invalid fiber context");
		MT_ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");

		// Attach the actual thread context to the fiber
		fiberContext->SetThreadContext(&threadContext);

		// Update task status
		fiberContext->SetStatus(FiberTaskStatus::RUNNED);

		MT_ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

		// Run current task code
		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);

		// Check whether the task has finished
		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
		if (taskStatus == FiberTaskStatus::FINISHED)
		{
			TaskGroup* taskGroup = fiberContext->currentGroup;

			int groupTaskCount = 0;

			// Update group status
			if (taskGroup != nullptr)
			{
				groupTaskCount = taskGroup->inProgressTaskCount.Dec();
				MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
				if (groupTaskCount == 0)
				{
					// Restore awaiting tasks
					threadContext.RestoreAwaitingTasks(taskGroup);
					taskGroup->allDoneEvent.Signal();
				}
			}

			// Update total task count
			groupTaskCount = threadContext.taskScheduler->allGroups.inProgressTaskCount.Dec();
			MT_ASSERT(groupTaskCount >= 0, "Sanity check failed!");
			if (groupTaskCount == 0)
			{
				// Notify that all tasks in all groups have finished
				threadContext.taskScheduler->allGroups.allDoneEvent.Signal();
			}

			FiberContext* parentFiberContext = fiberContext->parentFiber;
			if (parentFiberContext != nullptr)
			{
				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
				MT_ASSERT(childrenFibersCount >= 0, "Sanity check failed!");

				if (childrenFibersCount == 0)
				{
					// This is the last subtask. Restore the parent task
#if MT_FIBER_DEBUG

					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
					MT_ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");

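					// The bare statements below reference the debug locals, presumably to avoid
					// unused-variable warnings in builds where MT_ASSERT compiles to nothing.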
					ownerThread;
					parentTaskStatus;
					parentThreadContext;
					fiberUsageCounter;
#endif

					MT_ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
					MT_ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");

					// WARNING! The thread context may have changed here. Set the actual current thread context.
					parentFiberContext->SetThreadContext(&threadContext);

					MT_ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

					// All subtasks are done.
					// Exit and return the parent fiber to the scheduler
					return parentFiberContext;
				} else
				{
					// Other subtasks still exist
					// Exit
					return nullptr;
				}
			} else
			{
				// The task is finished and has no parent task
				// Exit
				return nullptr;
			}
		}

		MT_ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status");
		return nullptr;
	}


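	// Fiber entry point. Every pooled fiber loops forever: it runs the task currently bound to its
	// context, marks the task as finished and switches back to the scheduler fiber of the worker
	// thread that ran it. The loop body runs again the next time the fiber is reused for a new task.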
	void TaskScheduler::FiberMain(void* userData)
	{
		FiberContext& fiberContext = *(FiberContext*)(userData);
		for(;;)
		{
			MT_ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
			MT_ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
			MT_ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");

			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );

			fiberContext.SetStatus(FiberTaskStatus::FINISHED);

#ifdef MT_INSTRUMENTED_BUILD
			fiberContext.GetThreadContext()->NotifyTaskFinished(fiberContext.currentTask);
#endif

			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
		}

	}


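	// Work stealing: starting from a random victim, try to pop a task from the other workers' queues,
	// giving up after workersCount attempts.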
	bool TaskScheduler::TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount)
	{
		if (workersCount <= 1)
		{
			return false;
		}

		uint32 victimIndex = threadContext.random.Get();

		for (uint32 attempt = 0; attempt < workersCount; attempt++)
		{
			uint32 index = victimIndex % workersCount;
			if (index == threadContext.workerIndex)
			{
				victimIndex++;
				index = victimIndex % workersCount;
			}

			internal::ThreadContext& victimContext = threadContext.taskScheduler->threadContext[index];
			if (victimContext.queue.TryPop(task))
			{
				return true;
			}

			victimIndex++;
		}
		return false;
	}

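	// Worker thread entry point. The thread is converted into its scheduler fiber, then loops:
	// pop a task from the local queue (or steal one), run it on a task fiber, and keep executing
	// parent fibers handed back by ExecuteTask. When no work is found, the worker sleeps on
	// hasNewTasksEvent (with a timeout) until new tasks are pushed.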
	void TaskScheduler::ThreadMain( void* userData )
	{
		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
		MT_ASSERT(context.taskScheduler, "Task scheduler must be not null!");
		context.schedulerFiber.CreateFromThread(context.thread);

		uint32 workersCount = context.taskScheduler->GetWorkerCount();

		while(context.state.Get() != internal::ThreadState::EXIT)
		{
			internal::GroupedTask task;
			if (context.queue.TryPop(task) || TryStealTask(context, task, workersCount) )
			{
				// There is a new task
				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
				MT_ASSERT(fiberContext, "Can't get execution context from pool");
				MT_ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

				while(fiberContext)
				{
#ifdef MT_INSTRUMENTED_BUILD
					context.NotifyTaskResumed(fiberContext->currentTask);
#endif

					// Guard: prevent child tasks from resuming this fiber before ExecuteTask has returned
					fiberContext->childrenFibersCount.Inc();

					FiberContext* parentFiber = ExecuteTask(context, fiberContext);

					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

					// Release the guard
					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

					// The task is finished - the fiber context can be released
					if (taskStatus == FiberTaskStatus::FINISHED)
					{
						MT_ASSERT( childrenFibersCount == 0, "Sanity check failed");
						context.taskScheduler->ReleaseFiberContext(fiberContext);

						// If a parent fiber exists, transfer control to it; if the parent fiber is null, exit the loop
						fiberContext = parentFiber;
					} else
					{
						MT_ASSERT( childrenFibersCount >= 0, "Sanity check failed");

						// No subtasks remain and the status is not finished; this means all subtasks finished before the parent returned from ExecuteTask
						if (childrenFibersCount == 0)
						{
							MT_ASSERT(parentFiber == nullptr, "Sanity check failed");
						} else
						{
							// Subtasks still exist, so suspend execution of the current task. It will be resumed when the last subtask finishes
							break;
						}

						// If the task is in the await state, suspend execution. It will be resumed when RestoreAwaitingTasks is called
						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
						{
							break;
						}
					}
				} //while(fiberContext)

			} else
			{
#ifdef MT_INSTRUMENTED_BUILD
				int64 waitFrom = MT::GetTimeMicroSeconds();
#endif

				// The queue is empty and the stealing attempt failed
				// Wait for new tasks
				context.hasNewTasksEvent.Wait(2000);

#ifdef MT_INSTRUMENTED_BUILD
				int64 waitTo = MT::GetTimeMicroSeconds();
				context.NotifyWorkerAwait(waitFrom, waitTo);
#endif

			}

		} // main thread loop
	}

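	// Distributes the prepared task buckets to the workers. Every task gets its parent fiber pointer set
	// and bumps its group's in-progress counter; the parent fiber's children counter is increased by the
	// total task count, and so is the global in-progress counter unless the tasks are being restored from
	// an await state. Each bucket is then pushed to a worker queue chosen round-robin and the worker is signalled.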
	void TaskScheduler::RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
	{
		// Set the parent fiber pointer on every task,
		// update the per-group in-progress counters
		// and calculate the total number of tasks
		size_t count = 0;
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			internal::TaskBucket& bucket = buckets[i];
			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
			{
				internal::GroupedTask & task = bucket.tasks[taskIndex];

				task.parentFiber = parentFiber;

				if (task.group != nullptr)
				{
					//TODO: reduce the number of reset calls
					task.group->allDoneEvent.Reset();
					task.group->inProgressTaskCount.Inc();
				}
			}

			count += bucket.count;
		}

		// Increment the children fibers count on the parent fiber
		if (parentFiber)
		{
			parentFiber->childrenFibersCount.Add((uint32)count);
		}

		if (restoredFromAwaitState == false)
		{
			// Increment the global in-progress task counter
			allGroups.allDoneEvent.Reset();
			allGroups.inProgressTaskCount.Add((uint32)count);
		} else
		{
			// Tasks restored from the await state have already been counted, so the counters are left untouched
		}

		// Add the buckets to the worker thread queues
		for (size_t i = 0; i < buckets.Size(); ++i)
		{
			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
			internal::ThreadContext & context = threadContext[bucketIndex];

			internal::TaskBucket& bucket = buckets[i];

			context.queue.PushRange(bucket.tasks, bucket.count);
			context.hasNewTasksEvent.Signal();
		}
	}

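	// Blocks the calling thread until every task in the group has finished or the timeout expires.
	// Must not be called from a worker thread; tasks should use FiberContext::WaitGroupAndYield() instead.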
	bool TaskScheduler::WaitGroup(TaskGroup* group, uint32 milliseconds)
	{
		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);

		return group->allDoneEvent.Wait(milliseconds);
	}

	bool TaskScheduler::WaitAll(uint32 milliseconds)
	{
		MT_VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);

		return allGroups.allDoneEvent.Wait(milliseconds);
	}

	bool TaskScheduler::IsEmpty()
	{
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (!threadContext[i].queue.IsEmpty())
			{
				return false;
			}
		}
		return true;
	}

	uint32 TaskScheduler::GetWorkerCount() const
	{
		return threadsCount;
	}

	bool TaskScheduler::IsWorkerThread() const
	{
		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
		{
			if (threadContext[i].thread.IsCurrentThread())
			{
				return true;
			}
		}
		return false;
	}

#ifdef MT_INSTRUMENTED_BUILD

	size_t TaskScheduler::GetProfilerEvents(uint32 workerIndex, ProfileEventDesc * dstBuffer, size_t dstBufferSize)
	{
		if (workerIndex >= MT_MAX_THREAD_COUNT)
		{
			return 0;
		}

		size_t elementsCount = threadContext[workerIndex].profileEvents.PopAll(dstBuffer, dstBufferSize);
		return elementsCount;
	}

	void TaskScheduler::UpdateProfiler()
	{
		profilerWebServer.Update(*this);
	}

	int32 TaskScheduler::GetWebServerPort() const
	{
		return webServerPort;
	}

#endif
}