1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #include <MTScheduler.h>
24 
25 namespace MT
26 {
27 
28 	TaskScheduler::TaskScheduler()
29 		: roundRobinThreadIndex(0)
30 	{
31 		//query number of processor
32 		threadsCount = Max(Thread::GetNumberOfHardwareThreads() - 2, 1);
33 
34 		if (threadsCount > MT_MAX_THREAD_COUNT)
35 		{
36 			threadsCount = MT_MAX_THREAD_COUNT;
37 		}
38 
39 		// create fiber pool
40 		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
41 		{
42 			FiberContext& context = fiberContext[i];
43 			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
44 			availableFibers.Push( &context );
45 		}
46 
47 		// create worker thread pool
48 		for (uint32 i = 0; i < threadsCount; i++)
49 		{
50 			threadContext[i].taskScheduler = this;
51 			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
52 		}
53 	}
54 
55 	TaskScheduler::~TaskScheduler()
56 	{
57 		for (uint32 i = 0; i < threadsCount; i++)
58 		{
59 			threadContext[i].state.Set(internal::ThreadState::EXIT);
60 			threadContext[i].hasNewTasksEvent.Signal();
61 		}
62 
63 		for (uint32 i = 0; i < threadsCount; i++)
64 		{
65 			threadContext[i].thread.Stop();
66 		}
67 	}
68 
69 	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
70 	{
71 		FiberContext *fiberContext = task.awaitingFiber;
72 		if (fiberContext)
73 		{
74 			task.awaitingFiber = nullptr;
75 			return fiberContext;
76 		}
77 
78 		if (!availableFibers.TryPop(fiberContext))
79 		{
80 			ASSERT(false, "Fibers pool is empty");
81 		}
82 
83 		fiberContext->currentTask = task.desc;
84 		fiberContext->currentGroup = task.group;
85 		fiberContext->parentFiber = task.parentFiber;
86 		return fiberContext;
87 	}
88 
89 	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
90 	{
91 		ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber");
92 		fiberContext->Reset();
93 		availableFibers.Push(fiberContext);
94 	}
95 
96 	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
97 	{
98 		ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
99 
100 		ASSERT(fiberContext, "Invalid fiber context");
101 		ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");
102 		ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group");
103 
104 		// Set actual thread context to fiber
105 		fiberContext->SetThreadContext(&threadContext);
106 
107 		// Update task status
108 		fiberContext->SetStatus(FiberTaskStatus::RUNNED);
109 
110 		ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
111 
112 		// Run current task code
113 		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);
114 
115 		// If task was done
116 		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
117 		if (taskStatus == FiberTaskStatus::FINISHED)
118 		{
119 			TaskGroup::Type taskGroup = fiberContext->currentGroup;
120 			ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group.");
121 
122 			// Update group status
123 			int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec();
124 			ASSERT(groupTaskCount >= 0, "Sanity check failed!");
125 			if (groupTaskCount == 0)
126 			{
127 				// Restore awaiting tasks
128 				threadContext.RestoreAwaitingTasks(taskGroup);
129 				threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal();
130 			}
131 
132 			// Update total task count
133 			groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec();
134 			ASSERT(groupTaskCount >= 0, "Sanity check failed!");
135 			if (groupTaskCount == 0)
136 			{
137 				// Notify all tasks in all group finished
138 				threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal();
139 			}
140 
141 			FiberContext* parentFiberContext = fiberContext->parentFiber;
142 			if (parentFiberContext != nullptr)
143 			{
144 				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
145 				ASSERT(childrenFibersCount >= 0, "Sanity check failed!");
146 
147 				if (childrenFibersCount == 0)
148 				{
149 					// This is a last subtask. Restore parent task
150 #if FIBER_DEBUG
151 
152 					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
153 					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
154 					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
155 					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
156 					ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");
157 
158 					ownerThread;
159 					parentTaskStatus;
160 					parentThreadContext;
161 					fiberUsageCounter;
162 #endif
163 
164 					ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
165 					ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");
166 
167 					// WARNING!! Thread context can changed here! Set actual current thread context.
168 					parentFiberContext->SetThreadContext(&threadContext);
169 
170 					ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
171 
172 					// All subtasks is done.
173 					// Exiting and return parent fiber to scheduler
174 					return parentFiberContext;
175 				} else
176 				{
177 					// Other subtasks still exist
178 					// Exiting
179 					return nullptr;
180 				}
181 			} else
182 			{
183 				// Task is finished and no parent task
184 				// Exiting
185 				return nullptr;
186 			}
187 		}
188 
189 		ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status")
190 		return nullptr;
191 	}
192 
193 
194 	void TaskScheduler::FiberMain(void* userData)
195 	{
196 		FiberContext& fiberContext = *(FiberContext*)(userData);
197 		for(;;)
198 		{
199 			ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
200 			ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group");
201 			ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
202 			ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
203 
204 			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );
205 
206 			fiberContext.SetStatus(FiberTaskStatus::FINISHED);
207 
208 			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
209 		}
210 
211 	}
212 
213 
	// Entry point of every worker thread (the constructor starts it with &threadContext[i]).
	//
	// Creates the scheduler fiber from this thread (schedulerFiber.CreateFromThread), then loops
	// until the context's state becomes ThreadState::EXIT:
	//  * pop a task from this worker's queue and bind it to a fiber (RequestFiberContext);
	//  * run it via ExecuteTask, following the chain of parent fibers that ExecuteTask hands back
	//    whenever a task's last child finishes;
	//  * if the queue is empty, block on hasNewTasksEvent with a 2000 timeout (presumably
	//    milliseconds), so the EXIT state is re-checked periodically even without a signal.
	void TaskScheduler::ThreadMain( void* userData )
	{
		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
		ASSERT(context.taskScheduler, "Task scheduler must be not null!");
		context.schedulerFiber.CreateFromThread(context.thread);

		while(context.state.Get() != internal::ThreadState::EXIT)
		{
			internal::GroupedTask task;
			if (context.queue.TryPop(task))
			{
				// There is a new task
				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
				ASSERT(fiberContext, "Can't get execution context from pool");
				ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");

				// Run the fiber; on finish, continue with the parent fiber ExecuteTask returned (if any)
				while(fiberContext)
				{
					// Guard: hold an extra count on childrenFibersCount while ExecuteTask runs, so a child
					// finishing on another thread cannot drive the count to zero (and hand this fiber back
					// for resumption) before this thread is done with it.
					fiberContext->childrenFibersCount.Inc();

					FiberContext* parentFiber = ExecuteTask(context, fiberContext);

					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();

					// Release the guard; the result is the number of children still outstanding
					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();

					// Can drop fiber context - task is finished
					if (taskStatus == FiberTaskStatus::FINISHED)
					{
						ASSERT( childrenFibersCount == 0, "Sanity check failed");

						context.taskScheduler->ReleaseFiberContext(fiberContext);

						// If a parent fiber exists, transfer control to it; if it is null, exit the loop
						fiberContext = parentFiber;
					} else
					{
						ASSERT( childrenFibersCount >= 0, "Sanity check failed");

						// No subtasks left and status is not finished: all subtasks already finished before the parent returned from ExecuteTask
						if (childrenFibersCount == 0)
						{
							ASSERT(parentFiber == nullptr, "Sanity check failed");
						} else
						{
							// Subtasks still exist: drop current task execution. The task will be resumed when the last subtask finishes
							break;
						}

						// If the task is in await state, drop execution. The task will be resumed when RestoreAwaitingTasks is called
						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
						{
							break;
						}

						// Otherwise: children are all done and the task is not awaiting a group,
						// so loop again and resume this same fiber immediately.
					}
				} //while(fiberContext)

			} else
			{
				// Queue is empty
				// TODO: can try to steal tasks from other threads
				context.hasNewTasksEvent.Wait(2000);
			}

		} // main thread loop
	}
282 
283 	void TaskScheduler::RunTasksImpl(fixed_array<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
284 	{
285 		// Reset counter to initial value
286 		int taskCountInGroup[TaskGroup::COUNT];
287 		for (size_t i = 0; i < TaskGroup::COUNT; ++i)
288 		{
289 			taskCountInGroup[i] = 0;
290 		}
291 
292 		// Set parent fiber pointer
293 		// Calculate the number of tasks per group
294 		// Calculate total number of tasks
295 		size_t count = 0;
296 		for (size_t i = 0; i < buckets.size(); ++i)
297 		{
298 			internal::TaskBucket& bucket = buckets[i];
299 			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
300 			{
301 				internal::GroupedTask & task = bucket.tasks[taskIndex];
302 
303 				ASSERT(task.group < TaskGroup::COUNT, "Invalid group.");
304 
305 				task.parentFiber = parentFiber;
306 				taskCountInGroup[task.group]++;
307 			}
308 			count += bucket.count;
309 		}
310 
311 		// Increments child fibers count on parent fiber
312 		if (parentFiber)
313 		{
314 			parentFiber->childrenFibersCount.Add((uint32)count);
315 		}
316 
317 		if (restoredFromAwaitState == false)
318 		{
319 			// Increments all task in progress counter
320 			allGroupStats.allDoneEvent.Reset();
321 			allGroupStats.inProgressTaskCount.Add((uint32)count);
322 
323 			// Increments task in progress counters (per group)
324 			for (size_t i = 0; i < TaskGroup::COUNT; ++i)
325 			{
326 				int groupTaskCount = taskCountInGroup[i];
327 				if (groupTaskCount > 0)
328 				{
329 					groupStats[i].allDoneEvent.Reset();
330 					groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount);
331 				}
332 			}
333 		} else
334 		{
335 			// If task's restored from await state, counters already in correct state
336 		}
337 
338 		// Add to thread queue
339 		for (size_t i = 0; i < buckets.size(); ++i)
340 		{
341 			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
342 			internal::ThreadContext & context = threadContext[bucketIndex];
343 
344 			internal::TaskBucket& bucket = buckets[i];
345 
346 			context.queue.PushRange(bucket.tasks, bucket.count);
347 			context.hasNewTasksEvent.Signal();
348 		}
349 	}
350 
351 	bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds)
352 	{
353 		VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);
354 
355 		return groupStats[group].allDoneEvent.Wait(milliseconds);
356 	}
357 
358 	bool TaskScheduler::WaitAll(uint32 milliseconds)
359 	{
360 		VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);
361 
362 		return allGroupStats.allDoneEvent.Wait(milliseconds);
363 	}
364 
365 	bool TaskScheduler::IsEmpty()
366 	{
367 		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
368 		{
369 			if (!threadContext[i].queue.IsEmpty())
370 			{
371 				return false;
372 			}
373 		}
374 		return true;
375 	}
376 
377 	uint32 TaskScheduler::GetWorkerCount() const
378 	{
379 		return threadsCount;
380 	}
381 
382 	bool TaskScheduler::IsWorkerThread() const
383 	{
384 		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
385 		{
386 			if (threadContext[i].thread.IsCurrentThread())
387 			{
388 				return true;
389 			}
390 		}
391 		return false;
392 	}
393 
394 }
395