1 #include <MTScheduler.h>
2 
3 namespace MT
4 {
5 
6 	TaskScheduler::TaskScheduler()
7 		: roundRobinThreadIndex(0)
8 	{
9 		//query number of processor
10 		threadsCount = Max(Thread::GetNumberOfHardwareThreads() - 2, 1);
11 
12 		if (threadsCount > MT_MAX_THREAD_COUNT)
13 		{
14 			threadsCount = MT_MAX_THREAD_COUNT;
15 		}
16 
17 		// create fiber pool
18 		for (uint32 i = 0; i < MT_MAX_FIBERS_COUNT; i++)
19 		{
20 			FiberContext& context = fiberContext[i];
21 			context.fiber.Create(MT_FIBER_STACK_SIZE, FiberMain, &context);
22 			availableFibers.Push( &context );
23 		}
24 
25 		// create worker thread pool
26 		for (uint32 i = 0; i < threadsCount; i++)
27 		{
28 			threadContext[i].taskScheduler = this;
29 			threadContext[i].thread.Start( MT_SCHEDULER_STACK_SIZE, ThreadMain, &threadContext[i] );
30 		}
31 	}
32 
33 	TaskScheduler::~TaskScheduler()
34 	{
35 		for (uint32 i = 0; i < threadsCount; i++)
36 		{
37 			threadContext[i].state.Set(internal::ThreadState::EXIT);
38 			threadContext[i].hasNewTasksEvent.Signal();
39 		}
40 
41 		for (uint32 i = 0; i < threadsCount; i++)
42 		{
43 			threadContext[i].thread.Stop();
44 		}
45 	}
46 
47 	FiberContext* TaskScheduler::RequestFiberContext(internal::GroupedTask& task)
48 	{
49 		FiberContext *fiberContext = task.awaitingFiber;
50 		if (fiberContext)
51 		{
52 			task.awaitingFiber = nullptr;
53 			return fiberContext;
54 		}
55 
56 		if (!availableFibers.TryPop(fiberContext))
57 		{
58 			ASSERT(false, "Fibers pool is empty");
59 		}
60 
61 		fiberContext->currentTask = task.desc;
62 		fiberContext->currentGroup = task.group;
63 		fiberContext->parentFiber = task.parentFiber;
64 		return fiberContext;
65 	}
66 
67 	void TaskScheduler::ReleaseFiberContext(FiberContext* fiberContext)
68 	{
69 		ASSERT(fiberContext != nullptr, "Can't release nullptr Fiber");
70 		fiberContext->Reset();
71 		availableFibers.Push(fiberContext);
72 	}
73 
74 	FiberContext* TaskScheduler::ExecuteTask(internal::ThreadContext& threadContext, FiberContext* fiberContext)
75 	{
76 		ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
77 
78 		ASSERT(fiberContext, "Invalid fiber context");
79 		ASSERT(fiberContext->currentTask.IsValid(), "Invalid task");
80 		ASSERT(fiberContext->currentGroup < TaskGroup::COUNT, "Invalid task group");
81 
82 		// Set actual thread context to fiber
83 		fiberContext->SetThreadContext(&threadContext);
84 
85 		// Update task status
86 		fiberContext->SetStatus(FiberTaskStatus::RUNNED);
87 
88 		ASSERT(fiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
89 
90 		// Run current task code
91 		Fiber::SwitchTo(threadContext.schedulerFiber, fiberContext->fiber);
92 
93 		// If task was done
94 		FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
95 		if (taskStatus == FiberTaskStatus::FINISHED)
96 		{
97 			TaskGroup::Type taskGroup = fiberContext->currentGroup;
98 			ASSERT(taskGroup < TaskGroup::COUNT, "Invalid group.");
99 
100 			// Update group status
101 			int groupTaskCount = threadContext.taskScheduler->groupStats[taskGroup].inProgressTaskCount.Dec();
102 			ASSERT(groupTaskCount >= 0, "Sanity check failed!");
103 			if (groupTaskCount == 0)
104 			{
105 				// Restore awaiting tasks
106 				threadContext.RestoreAwaitingTasks(taskGroup);
107 				threadContext.taskScheduler->groupStats[taskGroup].allDoneEvent.Signal();
108 			}
109 
110 			// Update total task count
111 			groupTaskCount = threadContext.taskScheduler->allGroupStats.inProgressTaskCount.Dec();
112 			ASSERT(groupTaskCount >= 0, "Sanity check failed!");
113 			if (groupTaskCount == 0)
114 			{
115 				// Notify all tasks in all group finished
116 				threadContext.taskScheduler->allGroupStats.allDoneEvent.Signal();
117 			}
118 
119 			FiberContext* parentFiberContext = fiberContext->parentFiber;
120 			if (parentFiberContext != nullptr)
121 			{
122 				int childrenFibersCount = parentFiberContext->childrenFibersCount.Dec();
123 				ASSERT(childrenFibersCount >= 0, "Sanity check failed!");
124 
125 				if (childrenFibersCount == 0)
126 				{
127 					// This is a last subtask. Restore parent task
128 #if FIBER_DEBUG
129 
130 					int ownerThread = parentFiberContext->fiber.GetOwnerThread();
131 					FiberTaskStatus::Type parentTaskStatus = parentFiberContext->GetStatus();
132 					internal::ThreadContext * parentThreadContext = parentFiberContext->GetThreadContext();
133 					int fiberUsageCounter = parentFiberContext->fiber.GetUsageCounter();
134 					ASSERT(fiberUsageCounter == 0, "Parent fiber in invalid state");
135 
136 					ownerThread;
137 					parentTaskStatus;
138 					parentThreadContext;
139 					fiberUsageCounter;
140 #endif
141 
142 					ASSERT(threadContext.thread.IsCurrentThread(), "Thread context sanity check failed");
143 					ASSERT(parentFiberContext->GetThreadContext() == nullptr, "Inactive parent should not have a valid thread context");
144 
145 					// WARNING!! Thread context can changed here! Set actual current thread context.
146 					parentFiberContext->SetThreadContext(&threadContext);
147 
148 					ASSERT(parentFiberContext->GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
149 
150 					// All subtasks is done.
151 					// Exiting and return parent fiber to scheduler
152 					return parentFiberContext;
153 				} else
154 				{
155 					// Other subtasks still exist
156 					// Exiting
157 					return nullptr;
158 				}
159 			} else
160 			{
161 				// Task is finished and no parent task
162 				// Exiting
163 				return nullptr;
164 			}
165 		}
166 
167 		ASSERT(taskStatus != FiberTaskStatus::RUNNED, "Incorrect task status")
168 		return nullptr;
169 	}
170 
171 
172 	void TaskScheduler::FiberMain(void* userData)
173 	{
174 		FiberContext& fiberContext = *(FiberContext*)(userData);
175 		for(;;)
176 		{
177 			ASSERT(fiberContext.currentTask.IsValid(), "Invalid task in fiber context");
178 			ASSERT(fiberContext.currentGroup < TaskGroup::COUNT, "Invalid task group");
179 			ASSERT(fiberContext.GetThreadContext(), "Invalid thread context");
180 			ASSERT(fiberContext.GetThreadContext()->thread.IsCurrentThread(), "Thread context sanity check failed");
181 
182 			fiberContext.currentTask.taskFunc( fiberContext, fiberContext.currentTask.userData );
183 
184 			fiberContext.SetStatus(FiberTaskStatus::FINISHED);
185 
186 			Fiber::SwitchTo(fiberContext.fiber, fiberContext.GetThreadContext()->schedulerFiber);
187 		}
188 
189 	}
190 
191 
192 	void TaskScheduler::ThreadMain( void* userData )
193 	{
194 		internal::ThreadContext& context = *(internal::ThreadContext*)(userData);
195 		ASSERT(context.taskScheduler, "Task scheduler must be not null!");
196 		context.schedulerFiber.CreateFromThread(context.thread);
197 
198 		while(context.state.Get() != internal::ThreadState::EXIT)
199 		{
200 			internal::GroupedTask task;
201 			if (context.queue.TryPop(task))
202 			{
203 				// There is a new task
204 				FiberContext* fiberContext = context.taskScheduler->RequestFiberContext(task);
205 				ASSERT(fiberContext, "Can't get execution context from pool");
206 				ASSERT(fiberContext->currentTask.IsValid(), "Sanity check failed");
207 
208 				while(fiberContext)
209 				{
210 					// prevent invalid fiber resume from child tasks, before ExecuteTask is done
211 					fiberContext->childrenFibersCount.Inc();
212 
213 					FiberContext* parentFiber = ExecuteTask(context, fiberContext);
214 
215 					FiberTaskStatus::Type taskStatus = fiberContext->GetStatus();
216 
217 					//release guard
218 					int childrenFibersCount = fiberContext->childrenFibersCount.Dec();
219 
220 					// Can drop fiber context - task is finished
221 					if (taskStatus == FiberTaskStatus::FINISHED)
222 					{
223 						ASSERT( childrenFibersCount == 0, "Sanity check failed");
224 
225 						context.taskScheduler->ReleaseFiberContext(fiberContext);
226 
227 						// If parent fiber is exist transfer flow control to parent fiber, if parent fiber is null, exit
228 						fiberContext = parentFiber;
229 					} else
230 					{
231 						ASSERT( childrenFibersCount >= 0, "Sanity check failed");
232 
233 						// No subtasks here and status is not finished, this mean all subtasks already finished before parent return from ExecuteTask
234 						if (childrenFibersCount == 0)
235 						{
236 							ASSERT(parentFiber == nullptr, "Sanity check failed");
237 						} else
238 						{
239 							// If subtasks still exist, drop current task execution. task will be resumed when last subtask finished
240 							break;
241 						}
242 
243 						// If task is in await state drop execution. task will be resumed when RestoreAwaitingTasks called
244 						if (taskStatus == FiberTaskStatus::AWAITING_GROUP)
245 						{
246 							break;
247 						}
248 					}
249 				} //while(fiberContext)
250 
251 			} else
252 			{
253 				// Queue is empty
254 				// TODO: can try to steal tasks from other threads
255 				context.hasNewTasksEvent.Wait(2000);
256 			}
257 
258 		} // main thread loop
259 	}
260 
261 	void TaskScheduler::RunTasksImpl(fixed_array<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState)
262 	{
263 		// Reset counter to initial value
264 		int taskCountInGroup[TaskGroup::COUNT];
265 		for (size_t i = 0; i < TaskGroup::COUNT; ++i)
266 		{
267 			taskCountInGroup[i] = 0;
268 		}
269 
270 		// Set parent fiber pointer
271 		// Calculate the number of tasks per group
272 		// Calculate total number of tasks
273 		size_t count = 0;
274 		for (size_t i = 0; i < buckets.size(); ++i)
275 		{
276 			internal::TaskBucket& bucket = buckets[i];
277 			for (size_t taskIndex = 0; taskIndex < bucket.count; taskIndex++)
278 			{
279 				internal::GroupedTask & task = bucket.tasks[taskIndex];
280 
281 				ASSERT(task.group < TaskGroup::COUNT, "Invalid group.");
282 
283 				task.parentFiber = parentFiber;
284 				taskCountInGroup[task.group]++;
285 			}
286 			count += bucket.count;
287 		}
288 
289 		// Increments child fibers count on parent fiber
290 		if (parentFiber)
291 		{
292 			parentFiber->childrenFibersCount.Add((uint32)count);
293 		}
294 
295 		if (restoredFromAwaitState == false)
296 		{
297 			// Increments all task in progress counter
298 			allGroupStats.allDoneEvent.Reset();
299 			allGroupStats.inProgressTaskCount.Add((uint32)count);
300 
301 			// Increments task in progress counters (per group)
302 			for (size_t i = 0; i < TaskGroup::COUNT; ++i)
303 			{
304 				int groupTaskCount = taskCountInGroup[i];
305 				if (groupTaskCount > 0)
306 				{
307 					groupStats[i].allDoneEvent.Reset();
308 					groupStats[i].inProgressTaskCount.Add((uint32)groupTaskCount);
309 				}
310 			}
311 		} else
312 		{
313 			// If task's restored from await state, counters already in correct state
314 		}
315 
316 		// Add to thread queue
317 		for (size_t i = 0; i < buckets.size(); ++i)
318 		{
319 			int bucketIndex = roundRobinThreadIndex.Inc() % threadsCount;
320 			internal::ThreadContext & context = threadContext[bucketIndex];
321 
322 			internal::TaskBucket& bucket = buckets[i];
323 
324 			context.queue.PushRange(bucket.tasks, bucket.count);
325 			context.hasNewTasksEvent.Signal();
326 		}
327 	}
328 
329 	bool TaskScheduler::WaitGroup(TaskGroup::Type group, uint32 milliseconds)
330 	{
331 		VERIFY(IsWorkerThread() == false, "Can't use WaitGroup inside Task. Use FiberContext.WaitGroupAndYield() instead.", return false);
332 
333 		return groupStats[group].allDoneEvent.Wait(milliseconds);
334 	}
335 
336 	bool TaskScheduler::WaitAll(uint32 milliseconds)
337 	{
338 		VERIFY(IsWorkerThread() == false, "Can't use WaitAll inside Task.", return false);
339 
340 		return allGroupStats.allDoneEvent.Wait(milliseconds);
341 	}
342 
343 	bool TaskScheduler::IsEmpty()
344 	{
345 		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
346 		{
347 			if (!threadContext[i].queue.IsEmpty())
348 			{
349 				return false;
350 			}
351 		}
352 		return true;
353 	}
354 
355 	uint32 TaskScheduler::GetWorkerCount() const
356 	{
357 		return threadsCount;
358 	}
359 
360 	bool TaskScheduler::IsWorkerThread() const
361 	{
362 		for (uint32 i = 0; i < MT_MAX_THREAD_COUNT; i++)
363 		{
364 			if (threadContext[i].thread.IsCurrentThread())
365 			{
366 				return true;
367 			}
368 		}
369 		return false;
370 	}
371 
372 }
373