1 #pragma once
2 
3 #include <MTTools.h>
4 #include <MTPlatform.h>
5 #include <MTConcurrentQueueLIFO.h>
6 #include <MTStackArray.h>
7 #include <MTFixedArray.h>
8 
9 
10 namespace MT
11 {
12 	const uint32 MT_MAX_THREAD_COUNT = 32;
13 	const uint32 MT_MAX_FIBERS_COUNT = 128;
14 	const uint32 MT_SCHEDULER_STACK_SIZE = 131072;
15 	const uint32 MT_FIBER_STACK_SIZE = 32768;
16 
17 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
18 	// Task group
19 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
20 	// Application can wait until whole group was finished.
21 	namespace TaskGroup
22 	{
23 		enum Type
24 		{
25 			GROUP_0 = 0,
26 			GROUP_1 = 1,
27 			GROUP_2 = 2,
28 
29 			COUNT,
30 
31 			GROUP_UNDEFINED
32 		};
33 	}
34 
35 
36 	class FiberContext;
37 
38 	typedef void (*TTaskEntryPoint)(FiberContext & context, void* userData);
39 
40 
41 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
42 	// Task description
43 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
44 	struct TaskDesc
45 	{
46 		//Task entry point
47 		TTaskEntryPoint taskFunc;
48 
49 		//Task user data (task context)
50 		void* userData;
51 
52 		TaskDesc()
53 			: taskFunc(nullptr)
54 			, userData(nullptr)
55 		{
56 		}
57 
58 		TaskDesc(TTaskEntryPoint _taskFunc, void* _userData)
59 			: taskFunc(_taskFunc)
60 			, userData(_userData)
61 		{
62 		}
63 
64 		bool IsValid()
65 		{
66 			return (taskFunc != nullptr);
67 		}
68 	};
69 
70 	struct GroupedTask;
71 	struct ThreadContext;
72 
73 	class TaskScheduler;
74 
75 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
76 	struct TaskBucket
77 	{
78 		GroupedTask* tasks;
79 		size_t count;
80 		TaskBucket(GroupedTask* _tasks, size_t _count)
81 			: tasks(_tasks)
82 			, count(_count)
83 		{
84 		}
85 	};
86 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
87 
88 
89 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
90 	// Fiber task status
91 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
92 	// Task can be completed for several reasons.
93 	// For example task was done or someone call Yield from the Task body.
94 	namespace FiberTaskStatus
95 	{
96 		enum Type
97 		{
98 			UNKNOWN = 0,
99 			RUNNED = 1,
100 			FINISHED = 2,
101 			AWAITING_GROUP = 3,
102 			AWAITING_CHILD = 4,
103 		};
104 	}
105 
106 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
107 	// Fiber context
108 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
109 	// Context passed to fiber main function
110 	class FiberContext
111 	{
112 	private:
113 
114 		void RunSubtasksAndYieldImpl(fixed_array<TaskBucket>& buckets);
115 
116 	public:
117 
118 		FiberContext();
119 
120 		template<class TTask>
121 		void RunSubtasksAndYield(TaskGroup::Type taskGroup, const TTask* taskArray, size_t count);
122 
123 		template<class TTask>
124 		void RunAsync(TaskGroup::Type taskGroup, TTask* taskArray, uint32 count);
125 
126 		void WaitGroupAndYield(TaskGroup::Type group);
127 
128 		void Reset();
129 
130 		void SetThreadContext(ThreadContext * _threadContext);
131 		ThreadContext* GetThreadContext();
132 
133 		void SetStatus(FiberTaskStatus::Type _taskStatus);
134 		FiberTaskStatus::Type GetStatus() const;
135 
136 	private:
137 
138 		// Active thread context (null if fiber context is not executing now)
139 		ThreadContext * threadContext;
140 
141 		// Active task status
142 		FiberTaskStatus::Type taskStatus;
143 
144 	public:
145 
146 		// Active task attached to this fiber
147 		TaskDesc currentTask;
148 
149 
150 		// Active task group
151 		TaskGroup::Type currentGroup;
152 
153 		// Number of children fibers
154 		AtomicInt childrenFibersCount;
155 
156 		// Parent fiber
157 		FiberContext* parentFiber;
158 
159 		// System Fiber
160 		Fiber fiber;
161 
162 		// Prevent false sharing between threads
163 		uint8 cacheline[64];
164 	};
165 
166 
167 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
168 	struct ThreadState
169 	{
170 		enum Type
171 		{
172 			ALIVE,
173 			EXIT,
174 		};
175 	};
176 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
177 	struct GroupedTask
178 	{
179 		FiberContext* awaitingFiber;
180 		FiberContext* parentFiber;
181 		TaskGroup::Type group;
182 		TaskDesc desc;
183 
184 		GroupedTask()
185 			: parentFiber(nullptr)
186 			, awaitingFiber(nullptr)
187 			, group(TaskGroup::GROUP_UNDEFINED)
188 		{}
189 
190 		GroupedTask(TaskDesc& _desc, TaskGroup::Type _group)
191 			: parentFiber(nullptr)
192 			, awaitingFiber(nullptr)
193 			, group(_group)
194 			, desc(_desc)
195 		{}
196 	};
197 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
198 	// Thread (Scheduler fiber) context
199 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
200 	struct ThreadContext
201 	{
202 		FiberContext* lastActiveFiberContext;
203 
204 		// pointer to task manager
205 		TaskScheduler* taskScheduler;
206 
207 		// thread
208 		Thread thread;
209 
210 		// scheduler fiber
211 		Fiber schedulerFiber;
212 
213 		// task queue awaiting execution
214 		ConcurrentQueueLIFO<GroupedTask> queue;
215 
216 		// new task was arrived to queue event
217 		Event hasNewTasksEvent;
218 
219 		// whether thread is alive
220 		AtomicInt state;
221 
222 		// Temporary buffer
223 		std::vector<GroupedTask> descBuffer;
224 
225 		// prevent false sharing between threads
226 		uint8 cacheline[64];
227 
228 		ThreadContext();
229 		~ThreadContext();
230 
231 		void RestoreAwaitingTasks(TaskGroup::Type taskGroup);
232 	};
233 
234 
235 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
236 	// Task scheduler
237 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
238 	class TaskScheduler
239 	{
240 		friend class FiberContext;
241 		friend struct ThreadContext;
242 
243 		struct GroupStats
244 		{
245 			AtomicInt inProgressTaskCount;
246 			Event allDoneEvent;
247 
248 			GroupStats()
249 			{
250 				inProgressTaskCount.Set(0);
251 				allDoneEvent.Create( EventReset::MANUAL, true );
252 			}
253 		};
254 
255 		// Thread index for new task
256 		AtomicInt roundRobinThreadIndex;
257 
258 		// Threads created by task manager
259 		uint32 threadsCount;
260 		ThreadContext threadContext[MT_MAX_THREAD_COUNT];
261 
262 		// Per group task statistic
263 		GroupStats groupStats[TaskGroup::COUNT];
264 
265 		// All groups task statistic
266 		GroupStats allGroupStats;
267 
268 
269 		//Task awaiting group through FiberContext::WaitGroupAndYield call
270 		ConcurrentQueueLIFO<FiberContext*> waitTaskQueues[TaskGroup::COUNT];
271 
272 
273 		// Fibers pool
274 		ConcurrentQueueLIFO<FiberContext*> availableFibers;
275 
276 		// Fibers context
277 		FiberContext fiberContext[MT_MAX_FIBERS_COUNT];
278 
279 		FiberContext* RequestFiberContext(GroupedTask& task);
280 		void ReleaseFiberContext(FiberContext* fiberExecutionContext);
281 
282 		void RunTasksImpl(fixed_array<TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState);
283 
284 		static void ThreadMain( void* userData );
285 		static void FiberMain( void* userData );
286 		static FiberContext* ExecuteTask (ThreadContext& threadContext, FiberContext* fiberContext);
287 
288 
289 		template<class T>
290 		GroupedTask GetGroupedTask(TaskGroup::Type group, T * src) const
291 		{
292 			TaskDesc desc(T::TaskEntryPoint, (void*)(src));
293 			return GroupedTask(desc, group);
294 		}
295 
296 		//template specialization for FiberContext*
297 		template<>
298 		GroupedTask GetGroupedTask(TaskGroup::Type group, FiberContext ** src) const
299 		{
300 			ASSERT(group == TaskGroup::GROUP_UNDEFINED, "Group must be GROUP_UNDEFINED");
301 			FiberContext * fiberContext = *src;
302 			GroupedTask groupedTask(fiberContext->currentTask, fiberContext->currentGroup);
303 			groupedTask.awaitingFiber = fiberContext;
304 			return groupedTask;
305 		}
306 
307 
308 		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
309 		// Distributes task to threads:
310 		// | Task1 | Task2 | Task3 | Task4 | Task5 | Task6 |
311 		// ThreadCount = 4
312 		// Thread0: Task1, Task5
313 		// Thread1: Task2, Task6
314 		// Thread2: Task3
315 		// Thread3: Task4
316 		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
317 		template<class TTask>
318 		bool DistibuteDescriptions(TaskGroup::Type group, TTask* taskArray, fixed_array<GroupedTask>& descriptions, fixed_array<TaskBucket>& buckets) const
319 		{
320 			size_t index = 0;
321 
322 			for (size_t bucketIndex = 0; (bucketIndex < buckets.size()) && (index < descriptions.size()); ++bucketIndex)
323 			{
324 				size_t bucketStartIndex = index;
325 
326 				for (size_t i = bucketIndex; i < descriptions.size(); i += buckets.size())
327 				{
328 					descriptions[index++] = GetGroupedTask(group, &taskArray[i]);
329 				}
330 
331 				buckets[bucketIndex] = TaskBucket(&descriptions[bucketStartIndex], index - bucketStartIndex);
332 			}
333 
334 			ASSERT(index == descriptions.size(), "Sanity check")
335 
336 			return index > 0;
337 		}
338 
339 	public:
340 
341 		TaskScheduler();
342 		~TaskScheduler();
343 
344 		template<class TTask>
345 		void RunAsync(TaskGroup::Type group, TTask* taskArray, uint32 count);
346 
347 		bool WaitGroup(TaskGroup::Type group, uint32 milliseconds);
348 		bool WaitAll(uint32 milliseconds);
349 
350 		bool IsEmpty();
351 
352 		uint32 GetWorkerCount() const;
353 
354 		bool IsWorkerThread() const;
355 	};
356 
357     template<class TTask>
358     void TaskScheduler::RunAsync(TaskGroup::Type group, TTask* taskArray, uint32 count)
359     {
360         ASSERT(!IsWorkerThread(), "Can't use RunAsync inside Task. Use FiberContext.RunAsync() instead.");
361 
362         fixed_array<GroupedTask> buffer(ALLOCATE_ON_STACK(GroupedTask, count), count);
363 
364         size_t bucketCount = Min(threadsCount, count);
365         fixed_array<TaskBucket>	buckets(ALLOCATE_ON_STACK(TaskBucket, bucketCount), bucketCount);
366 
367         DistibuteDescriptions(group, taskArray, buffer, buckets);
368         RunTasksImpl(buckets, nullptr, false);
369     }
370 
371     template<class TTask>
372     void FiberContext::RunSubtasksAndYield(TaskGroup::Type taskGroup, const TTask* taskArray, size_t count)
373     {
374         ASSERT(threadContext, "ThreadContext is NULL");
375         ASSERT(count < threadContext->descBuffer.size(), "Buffer overrun!")
376 
377         size_t threadsCount = threadContext->taskScheduler->GetWorkerCount();
378 
379         fixed_array<GroupedTask> buffer(&threadContext->descBuffer.front(), count);
380 
381         size_t bucketCount = Min(threadsCount, count);
382         fixed_array<TaskBucket>	buckets(ALLOCATE_ON_STACK(TaskBucket, bucketCount), bucketCount);
383 
384         threadContext->taskScheduler->DistibuteDescriptions(taskGroup, taskArray, buffer, buckets);
385         RunSubtasksAndYieldImpl(buckets);
386     }
387 
388     template<class TTask>
389     void FiberContext::RunAsync(TaskGroup::Type taskGroup, TTask* taskArray, uint32 count)
390     {
391         ASSERT(threadContext, "ThreadContext is NULL");
392         ASSERT(threadContext->taskScheduler->IsWorkerThread(), "Can't use RunAsync outside Task. Use TaskScheduler.RunAsync() instead.");
393 
394         TaskScheduler& scheduler = *(threadContext->taskScheduler);
395 
396         fixed_array<GroupedTask> buffer(&threadContext->descBuffer.front(), count);
397 
398         size_t bucketCount = Min(scheduler.GetWorkerCount(), count);
399         fixed_array<TaskBucket>	buckets(ALLOCATE_ON_STACK(TaskBucket, bucketCount), bucketCount);
400 
401         scheduler.DistibuteDescriptions(taskGroup, taskArray, buffer, buckets);
402         scheduler.RunTasksImpl(buckets, nullptr, false);
403     }
404 
405 
406 
407 
408 
409 		template<typename T>
410 		struct TaskBase
411 		{
412 			static void TaskEntryPoint(MT::FiberContext& fiberContext, void* userData)
413 			{
414 				T* task = static_cast<T*>(userData);
415 				task->Do(fiberContext);
416 			}
417 		};
418 
419 
420 
421 }
422