1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #pragma once
24 
25 #include <MTConfig.h>
26 #include <MTColorTable.h>
27 #include <MTTools.h>
28 #include <MTPlatform.h>
29 #include <MTQueueMPMC.h>
30 #include <MTArrayView.h>
31 #include <MTThreadContext.h>
32 #include <MTFiberContext.h>
33 #include <MTAppInterop.h>
34 #include <MTTaskPool.h>
35 #include <MTStackRequirements.h>
36 #include <Scopes/MTScopes.h>
37 
38 
namespace MT
{

	/// Compile-time guard used by MT_DECLARE_TASK.
	/// Instantiating CheckType<A, B> fails with a static_assert unless A and B are exactly the same type;
	/// the parameter names show up in the compiler diagnostic, so they are kept descriptive.
	template<class CLASS_TYPE, class MACRO_TYPE>
	struct CheckType
	{
		static_assert(std::is_same<CLASS_TYPE, MACRO_TYPE>::value,
			"Invalid type in MT_DECLARE_TASK macro. See CheckType template instantiation params to details.");
	};

	/// Helper for recovering the static type of `this` inside a member function via decltype().
	struct TypeChecker
	{
		/// Returns a null pointer of exactly the argument's type; the argument value itself is ignored.
		template <class PointerType>
		static PointerType QueryThisType(PointerType thisPtr)
		{
			MT_UNUSED(thisPtr);
			return static_cast<PointerType>(nullptr);
		}
	};


	/// Explicitly runs ObjectType's destructor on the object at p without releasing its storage.
	/// MT_UNUSED(p) is kept even though p is used below, presumably to silence "unreferenced parameter"
	/// warnings on compilers that drop the call entirely when the destructor is trivial.
	template <class ObjectType>
	inline void CallDtor(ObjectType* p)
	{
		MT_UNUSED(p);
		p->~ObjectType();
	}

}
67 
68 #if MT_MSVC_COMPILER_FAMILY
69 
70 // Visual Studio compile time check
71 #define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
72 	void CompileTimeCheckMethod() \
73 	{ \
74 		MT::CheckType< typename std::remove_pointer< decltype(MT::TypeChecker::QueryThisType(this)) >::type, typename TYPE > compileTypeTypesCheck; \
75 		compileTypeTypesCheck; \
76 	}
77 
78 #elif MT_GCC_COMPILER_FAMILY
79 
80 // GCC, Clang and other compilers compile time check
81 #define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
82 	void CompileTimeCheckMethod() \
83 	{ \
84 		/* query this pointer type */ \
85 		typedef decltype(MT::TypeChecker::QueryThisType(this)) THIS_PTR_TYPE; \
86 		/* query class type from this pointer type */ \
87 		typedef typename std::remove_pointer<THIS_PTR_TYPE>::type CPP_TYPE; \
88 		/* define macro type */ \
89 		typedef TYPE MACRO_TYPE; \
90 		/* compile time checking that is same types */ \
91 		MT::CheckType< CPP_TYPE, MACRO_TYPE > compileTypeTypesCheck; \
92 		/* remove unused variable warning */ \
93 		MT_UNUSED(compileTypeTypesCheck); \
94 	}
95 
96 #else
97 
98 #error Platform is not supported.
99 
100 #endif
101 
102 
// Core of MT_DECLARE_TASK: expanded inside the body of the task class TYPE, it adds the static
// hooks the scheduler uses for that task type.
//   TYPE               - the enclosing task class (verified by MT_COMPILE_TIME_TYPE_CHECK)
//   STACK_REQUIREMENTS - value returned by GetStackRequirements() (MT::StackRequirements::Type)
//   TASK_PRIORITY      - value returned by GetTaskPriority() (MT::TaskPriority::Type)
// Only /* */ comments may appear inside the macro body: a // comment would extend across every
// following spliced line and swallow the rest of the macro. The last line has no trailing
// backslash so nothing after the macro is accidentally pulled into it.
#define MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY) \
	\
	MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	\
	static void TaskEntryPoint(MT::FiberContext& fiberContext, const void* userData) \
	{ \
		/* userData points at the task object; remove the const imposed by the void* callback signature */ \
		TYPE * task = static_cast<TYPE *>(const_cast<void*>(userData)); \
		task->Do(fiberContext); \
	} \
	\
	static void PoolTaskDestroy(const void* userData) \
	{ \
		void* taskMemory = const_cast<void*>(userData); \
		TYPE * task = static_cast<TYPE *>(taskMemory); \
		MT::CallDtor( task ); \
		/* Find task pool header: it is stored immediately before the task object */ \
		MT::PoolElementHeader * poolHeader = reinterpret_cast<MT::PoolElementHeader *>(static_cast<char*>(taskMemory) - sizeof(MT::PoolElementHeader)); \
		/* Fixup pool header, mark task as unused */ \
		poolHeader->id.Store(MT::TaskID::UNUSED); \
	} \
	\
	static MT::StackRequirements::Type GetStackRequirements() \
	{ \
		return STACK_REQUIREMENTS; \
	} \
	\
	static MT::TaskPriority::Type GetTaskPriority() \
	{ \
		return TASK_PRIORITY; \
	}
134 
135 
#ifdef MT_INSTRUMENTED_BUILD
#include <MTProfilerEventListener.h>

// Declares a task type; expand once inside the body of the task class.
//   TYPE               - the task class itself (verified at compile time)
//   STACK_REQUIREMENTS - value returned by GetStackRequirements() (MT::StackRequirements::Type)
//   TASK_PRIORITY      - value returned by GetTaskPriority() (MT::TaskPriority::Type)
//   DEBUG_COLOR        - value returned by GetDebugColor() (MT::Color::Type)
// Instrumented builds additionally get GetDebugID(), which returns the stringized TYPE.
// The expansion ends with a member function definition and deliberately carries no trailing ';',
// so the usual `MT_DECLARE_TASK(...);` call site does not leave a stray `;;` in the class.
#define MT_DECLARE_TASK(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR) \
	static const mt_char* GetDebugID() \
	{ \
		return MT_TEXT( #TYPE ); \
	} \
	\
	static MT::Color::Type GetDebugColor() \
	{ \
		return DEBUG_COLOR; \
	} \
	\
	MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY)


#else

// Non-instrumented build: DEBUG_COLOR is accepted for source compatibility and ignored.
#define MT_DECLARE_TASK(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR) \
	MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY)

#endif
159 
160 
161 
162 
// Task-group debug bookkeeping (TaskGroupDescription::debugIsFree) is compiled in only for
// instrumented or debug builds; otherwise MT_GROUP_DEBUG stays undefined.
#if defined(MT_INSTRUMENTED_BUILD) || defined(MT_DEBUG)
#define MT_GROUP_DEBUG (1)
#endif
166 
167 
168 
namespace MT
{
	// Maximum number of worker threads; sizes TaskScheduler::threadContext.
	const uint32 MT_MAX_THREAD_COUNT = 64;
	// Scheduler stack size in bytes; its consumer is not visible in this header.
	const uint32 MT_SCHEDULER_STACK_SIZE = 1048576; // 1Mb

	// Standard fiber pool: number of fibers (sizes TaskScheduler::standartFiberContexts)
	// and the stack size of each fiber, in bytes.
	const uint32 MT_MAX_STANDART_FIBERS_COUNT = 256;
	const uint32 MT_STANDART_FIBER_STACK_SIZE = 32768; //32Kb

	// Extended fiber pool: far fewer fibers, each with a much larger stack
	// (sizes TaskScheduler::extendedFiberContexts).
	const uint32 MT_MAX_EXTENDED_FIBERS_COUNT = 8;
	const uint32 MT_EXTENDED_FIBER_STACK_SIZE = 1048576; // 1Mb

	namespace internal
	{
		// Forward declaration. TaskScheduler befriends it and stores it by value (threadContext below),
		// so the full definition must already be visible — presumably from MTThreadContext.h.
		struct ThreadContext;
	}

	// Task stealing switch passed to the TaskScheduler constructor
	// (see TaskScheduler::TryStealTask and IsTaskStealingDisabled).
	namespace TaskStealingMode
	{
		enum Type
		{
			DISABLED = 0,
			ENABLED = 1,
		};
	}

	// Optional per-worker settings accepted by the TaskScheduler constructor.
	struct WorkerThreadParams
	{
		// CPU core for the worker; MT_CPUCORE_ANY by default (presumably "no affinity").
		uint32 core;
		// Worker thread priority; ThreadPriority::DEFAULT by default.
		ThreadPriority::Type priority;

		WorkerThreadParams()
			: core(MT_CPUCORE_ANY)
			, priority(ThreadPriority::DEFAULT)
		{
		}
	};

	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
	// Task scheduler
	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
	// Owns the worker thread contexts, the standard/extended fiber pools and the task-group bookkeeping.
	// Tasks are submitted with RunAsync and awaited with WaitGroup / WaitAll.
	class TaskScheduler
	{
		friend class FiberContext;
		friend struct internal::ThreadContext;



		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		// Task group description
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		// Application can assign task group to task and later wait until group was finished.
		// Holds an atomic counter of the group's tasks and a manual-reset event to wait on.
		class TaskGroupDescription
		{
			// Number of tasks currently counted against this group (starts at 0).
			Atomic32<int32> inProgressTaskCount;
			// Manual-reset event used by Wait(); created with initial state `true`
			// (presumably signaled, so waiting on an empty group returns immediately).
			Event allDoneEvent;

#if MT_GROUP_DEBUG
			// Debug-only flag, true after construction; presumably marks an unallocated group slot.
			bool debugIsFree;
#endif

		public:

			MT_NOCOPYABLE(TaskGroupDescription);

			TaskGroupDescription()
			{
				inProgressTaskCount.Store(0);
				allDoneEvent.Create( EventReset::MANUAL, true );

#if MT_GROUP_DEBUG
				debugIsFree = true;
#endif
			}

			// Current value of the task counter.
			int GetTaskCount() const
			{
				return inProgressTaskCount.Load();
			}

			// Dec / Inc / Add update the counter atomically and return the result of the
			// corresponding *Fetch call (presumably the updated value).
			int Dec()
			{
				return inProgressTaskCount.DecFetch();
			}

			int Inc()
			{
				return inProgressTaskCount.IncFetch();
			}

			int Add(int sum)
			{
				return inProgressTaskCount.AddFetch(sum);
			}

			// Signal / Reset / Wait forward to the "all done" event.
			void Signal()
			{
				allDoneEvent.Signal();
			}

			void Reset()
			{
				allDoneEvent.Reset();
			}

			// Returns the result of Event::Wait (presumably false when the timeout expires).
			bool Wait(uint32 milliseconds)
			{
				return allDoneEvent.Wait(milliseconds);
			}

#if MT_GROUP_DEBUG
			void SetDebugIsFree(bool _debugIsFree)
			{
				debugIsFree = _debugIsFree;
			}

			bool GetDebugIsFree() const
			{
				return debugIsFree;
			}
#endif
		};


		// Thread index for new task
		Atomic32<int32> roundRobinThreadIndex;

		// Started threads count
		Atomic32<int32> startedThreadsCount;

		// Threads created by task manager
		Atomic32<int32> threadsCount;
		// Per-worker contexts; fixed capacity of MT_MAX_THREAD_COUNT (presumably only threadsCount are used).
		internal::ThreadContext threadContext[MT_MAX_THREAD_COUNT];

		// All groups task statistic
		TaskGroupDescription allGroups;

		// Groups pool
		// Queue of TaskGroup handles with capacity 2x the group limit; presumably the free list behind CreateGroup / ReleaseGroup.
		LockFreeQueueMPMC<TaskGroup, TaskGroup::MT_MAX_GROUPS_COUNT * 2> availableGroups;

		// Per-group statistics, one entry per possible group (presumably looked up through GetGroupDesc).
		TaskGroupDescription groupStats[TaskGroup::MT_MAX_GROUPS_COUNT];

		// Fibers context
		// Storage for every fiber of the standard and extended pools.
		FiberContext standartFiberContexts[MT_MAX_STANDART_FIBERS_COUNT];
		FiberContext extendedFiberContexts[MT_MAX_EXTENDED_FIBERS_COUNT];

		// Fibers pool
		// Pointers to fiber contexts presumably available for reuse; queue capacity is 2x the pool size.
		LockFreeQueueMPMC<FiberContext*, MT_MAX_STANDART_FIBERS_COUNT * 2> standartFibersAvailable;
		LockFreeQueueMPMC<FiberContext*, MT_MAX_EXTENDED_FIBERS_COUNT * 2> extendedFibersAvailable;

#ifdef MT_INSTRUMENTED_BUILD
		// Profiler listener exposed through GetProfilerEventListener() (instrumented builds only).
		IProfilerEventListener * profilerEventListener;
#endif

		// Reported by IsTaskStealingDisabled(); presumably derived from the constructor's stealMode.
		bool taskStealingDisabled;

		// Fiber pool access: obtain a fiber context for a task / give one back.
		FiberContext* RequestFiberContext(internal::GroupedTask& task);
		void ReleaseFiberContext(FiberContext*&& fiberExecutionContext);
		void RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState);
		// Returns the bookkeeping record for the given group.
		TaskGroupDescription & GetGroupDesc(TaskGroup group);

		// Thread / fiber entry points; the meaning of userData is defined by the implementation file.
		static void WorkerThreadMain( void* userData );
		static void SchedulerFiberMain( void* userData );
		static void FiberMain( void* userData );
		static bool TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount, bool taskStealingDisabled);

		static FiberContext* ExecuteTask (internal::ThreadContext& threadContext, FiberContext* fiberContext);

	public:

		/// \brief Initializes a new instance of the TaskScheduler class.
		/// \param workerThreadsCount Worker threads count. Automatically determines the required number of threads if workerThreadsCount set to 0
		/// \param workerParameters Optional per-worker settings, nullptr for defaults. NOTE(review): presumably must hold one entry per worker when non-null — confirm in the implementation.
		/// \param listener Profiler event listener (MT_INSTRUMENTED_BUILD only); presumably what GetProfilerEventListener() returns.
		/// \param stealMode Enables or disables task stealing between workers (enabled by default).
#ifdef MT_INSTRUMENTED_BUILD
		TaskScheduler(uint32 workerThreadsCount = 0, WorkerThreadParams* workerParameters = nullptr, IProfilerEventListener* listener = nullptr, TaskStealingMode::Type stealMode = TaskStealingMode::ENABLED);
#else
		TaskScheduler(uint32 workerThreadsCount = 0, WorkerThreadParams* workerParameters = nullptr, TaskStealingMode::Type stealMode = TaskStealingMode::ENABLED);
#endif


		~TaskScheduler();

		/// \brief Schedules taskCount tasks from taskArray for asynchronous execution as members of group.
		template<class TTask>
		void RunAsync(TaskGroup group, const TTask* taskArray, uint32 taskCount);

		/// \brief Schedules tasks referenced by handles (presumably allocated from a TaskPool) as members of group.
		void RunAsync(TaskGroup group, const TaskHandle* taskHandleArray, uint32 taskHandleCount);

		/// \brief Wait while no more tasks in specific group.
		/// \return true - if no more tasks in specific group. false - if timeout in milliseconds has reached and group still has some tasks.
		bool WaitGroup(TaskGroup group, uint32 milliseconds);

		/// \brief Wait until no tasks remain in any group.
		/// \return NOTE(review): presumably mirrors WaitGroup — true when everything finished, false on timeout.
		bool WaitAll(uint32 milliseconds);

		/// \brief Allocates a task group / returns a group obtained from CreateGroup.
		TaskGroup CreateGroup();
		void ReleaseGroup(TaskGroup group);

		/// \brief Number of worker threads.
		int32 GetWorkersCount() const;

		/// \brief true when task stealing between workers is disabled.
		bool IsTaskStealingDisabled() const;

		/// \brief Presumably true when called from one of this scheduler's worker threads.
		bool IsWorkerThread() const;

#ifdef MT_INSTRUMENTED_BUILD

		inline IProfilerEventListener* GetProfilerEventListener()
		{
			return profilerEventListener;
		}

#endif
	};
}
380 
381 #include "MTScheduler.inl"
382 #include "MTFiberContext.inl"
383