1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #pragma once
24 
25 #include <MTConfig.h>
26 #include <MTColorTable.h>
27 #include <MTTools.h>
28 #include <MTPlatform.h>
29 #include <MTQueueMPMC.h>
30 #include <MTArrayView.h>
31 #include <MTThreadContext.h>
32 #include <MTFiberContext.h>
33 #include <MTAppInterop.h>
34 #include <MTTaskPool.h>
35 #include <MTStackRequirements.h>
36 #include <Scopes/MTScopes.h>
37 
/*
	Profiler injection hook.

	This macro is expanded as the first statement of the TaskEntryPoint() function
	generated by MT_DECLARE_TASK_IMPL, so whatever it expands to runs inside the task scope.
	Define it before this header is included to inject custom profiler code;
	when it is left undefined, the default below expands to nothing.

	TYPE        - task type passed to the task declaration macro
	DEBUG_COLOR - debug color passed to the task declaration macro
	SRC_FILE    - __FILE__ at the point of expansion
	SRC_LINE    - __LINE__ at the point of expansion
*/
#ifndef MT_SCHEDULER_PROFILER_TASK_SCOPE_CODE_INJECTION
#define MT_SCHEDULER_PROFILER_TASK_SCOPE_CODE_INJECTION( TYPE, DEBUG_COLOR, SRC_FILE, SRC_LINE)
#endif
44 
45 namespace MT
46 {
47 
48 	template<typename CLASS_TYPE, typename MACRO_TYPE>
49 	struct CheckType
50 	{
51 		static_assert(std::is_same<CLASS_TYPE, MACRO_TYPE>::value, "Invalid type in MT_DECLARE_TASK macro. See CheckType template instantiation params to details.");
52 	};
53 
54 	struct TypeChecker
55 	{
56 		template <typename T>
57 		static T QueryThisType(T thisPtr)
58 		{
59 			MT_UNUSED(thisPtr);
60 			return (T)nullptr;
61 		}
62 	};
63 
64 
65 	template <typename T>
66 	inline void CallDtor(T* p)
67 	{
68 		MT_UNUSED(p);
69 		p->~T();
70 	}
71 
72 }
73 
// MT_COMPILE_TIME_TYPE_CHECK(TYPE)
// Generates a member function CompileTimeCheckMethod() whose body instantiates
// MT::CheckType with (a) the class type obtained from the `this` pointer and (b) TYPE.
// The static_assert inside MT::CheckType then rejects a TYPE that is not the enclosing class.
// Two variants exist, selected by compiler family; any other compiler is rejected with #error.
#if MT_MSVC_COMPILER_FAMILY

// Visual Studio compile time check
#define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	void CompileTimeCheckMethod() \
	{ \
		MT::CheckType< typename std::remove_pointer< decltype(MT::TypeChecker::QueryThisType(this)) >::type, typename TYPE > compileTypeTypesCheck; \
		/* reference the variable; NOTE(review): presumably the counterpart of MT_UNUSED in the other variant - confirm */ \
		compileTypeTypesCheck; \
	}

#elif MT_GCC_COMPILER_FAMILY

// GCC, Clang and other compilers compile time check
#define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	void CompileTimeCheckMethod() \
	{ \
		/* query the type of the this pointer */ \
		typedef decltype(MT::TypeChecker::QueryThisType(this)) THIS_PTR_TYPE; \
		/* strip the pointer to get the class type */ \
		typedef typename std::remove_pointer<THIS_PTR_TYPE>::type CPP_TYPE; \
		/* type that was passed to the macro */ \
		typedef TYPE MACRO_TYPE; \
		/* compile-time check that both types are the same */ \
		MT::CheckType< CPP_TYPE, MACRO_TYPE > compileTypeTypesCheck; \
		/* remove unused variable warning */ \
		MT_UNUSED(compileTypeTypesCheck); \
	}

#else

#error Platform is not supported.

#endif
107 
108 
/*
	MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR)

	Generates, inside the body of task class TYPE:
	  - CompileTimeCheckMethod()  compile-time check that TYPE is the enclosing class
	  - TaskEntryPoint()          casts userData to TYPE* and calls Do(fiberContext) on it
	  - PoolTaskDestroy()         destroys a pool-allocated task and marks its pool slot unused
	  - GetStackRequirements()    returns STACK_REQUIREMENTS
	  - GetTaskPriority()         returns TASK_PRIORITY

	DEBUG_COLOR is only forwarded to MT_SCHEDULER_PROFILER_TASK_SCOPE_CODE_INJECTION.
*/
#define MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR) \
	\
	MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	\
	static void TaskEntryPoint(MT::FiberContext& fiberContext, const void* userData) \
	{ \
		MT_SCHEDULER_PROFILER_TASK_SCOPE_CODE_INJECTION(TYPE, DEBUG_COLOR, __FILE__, __LINE__); \
		/* userData is treated as the task object; constness is dropped because the task is used through a non-const pointer */ \
		TYPE * task = static_cast<TYPE *>(const_cast<void *>(userData)); \
		task->Do(fiberContext); \
	} \
	\
	static void PoolTaskDestroy(const void* userData) \
	{ \
		/* run the task destructor in place */ \
		TYPE * task = static_cast<TYPE *>(const_cast<void *>(userData)); \
		MT::CallDtor( task ); \
		/* the pool element header is expected sizeof(MT::PoolElementHeader) bytes before the task */ \
		char * taskMemory = static_cast<char *>(const_cast<void *>(userData)); \
		MT::PoolElementHeader * poolHeader = reinterpret_cast<MT::PoolElementHeader *>(taskMemory - sizeof(MT::PoolElementHeader)); \
		/* mark the pool slot as unused */ \
		poolHeader->id.Store(MT::TaskID::UNUSED); \
	} \
	\
	static MT::StackRequirements::Type GetStackRequirements() { return STACK_REQUIREMENTS; } \
	\
	static MT::TaskPriority::Type GetTaskPriority() { return TASK_PRIORITY; }

141 
142 
// MT_DECLARE_TASK(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR)
// Task declaration macro. It expands to member function definitions (static members plus a
// method that uses `this`), so it has to be placed inside the body of the task class TYPE.
//
// Both variants forward all arguments to MT_DECLARE_TASK_IMPL. When MT_INSTRUMENTED_BUILD is
// defined, two extra static members are generated:
//   GetDebugID()    - returns the stringified TYPE wrapped in MT_TEXT()
//   GetDebugColor() - returns DEBUG_COLOR
//
// Note: the expansion already ends with a ';'.
#ifdef MT_INSTRUMENTED_BUILD
#include <MTProfilerEventListener.h>

#define MT_DECLARE_TASK(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR) \
	static const mt_char* GetDebugID() \
	{ \
		return MT_TEXT( #TYPE ); \
	} \
	\
	static MT::Color::Type GetDebugColor() \
	{ \
		return DEBUG_COLOR; \
	} \
	\
	MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR);

#else

#define MT_DECLARE_TASK(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR) \
	MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS, TASK_PRIORITY, DEBUG_COLOR);

#endif
165 
166 
167 
168 
// MT_GROUP_DEBUG switches on the extra task group bookkeeping guarded by `#if MT_GROUP_DEBUG`
// further down in this file. It is defined (as 1) only when MT_DEBUG or MT_INSTRUMENTED_BUILD
// is defined and is left undefined otherwise.
#if defined(MT_DEBUG) || defined(MT_INSTRUMENTED_BUILD)
#define MT_GROUP_DEBUG (1)
#endif
172 
173 
174 
175 namespace MT
176 {
177 	const uint32 MT_MAX_THREAD_COUNT = 64;
178 	const uint32 MT_SCHEDULER_STACK_SIZE = 1048576; // 1Mb
179 
180 	const uint32 MT_MAX_STANDART_FIBERS_COUNT = 256;
181 	const uint32 MT_STANDART_FIBER_STACK_SIZE = 32768; //32Kb
182 
183 	const uint32 MT_MAX_EXTENDED_FIBERS_COUNT = 8;
184 	const uint32 MT_EXTENDED_FIBER_STACK_SIZE = 1048576; // 1Mb
185 
	namespace internal
	{
		// Forward declaration; TaskScheduler below names this type in a friend declaration,
		// a member array and several function parameters.
		struct ThreadContext;
	}
190 
191 	namespace TaskStealingMode
192 	{
193 		enum Type
194 		{
195 			DISABLED = 0,
196 			ENABLED = 1,
197 		};
198 	}
199 
200 	struct WorkerThreadParams
201 	{
202 		uint32 core;
203 		ThreadPriority::Type priority;
204 
205 		WorkerThreadParams()
206 			: core(MT_CPUCORE_ANY)
207 			, priority(ThreadPriority::DEFAULT)
208 		{
209 		}
210 	};
211 
212 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
213 	// Task scheduler
214 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
	// Task scheduler. Holds the worker thread contexts, the fiber contexts with their pools and
	// the task group bookkeeping declared below. Apart from the small inline helpers,
	// member functions are only declared in this class body.
	class TaskScheduler
	{
		// These types are granted access to the private members of the scheduler.
		friend class FiberContext;
		friend struct internal::ThreadContext;



		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		// Task group description
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		// The application can assign a task group to a task and later wait until the group has finished.
		// A description holds an atomic counter of in-progress tasks and, when MT_GROUP_DEBUG is set,
		// a debug flag telling whether the description is free.
		class TaskGroupDescription
		{
			// Atomic counter behind GetTaskCount()/Inc()/Dec()/Add()
			Atomic32<int32> inProgressTaskCount;

#if MT_GROUP_DEBUG
			// Debug-only flag; set to true by the constructor
			bool debugIsFree;
#endif

		public:

			// NOTE(review): presumably disables copying (macro is defined elsewhere) - confirm
			MT_NOCOPYABLE(TaskGroupDescription);

			// Starts with a zero task counter (and, in debug builds, marked as free).
			TaskGroupDescription()
			{
				inProgressTaskCount.Store(0);
#if MT_GROUP_DEBUG
				debugIsFree = true;
#endif
			}

			// Returns the current value of the task counter.
			int32 GetTaskCount() const
			{
				return inProgressTaskCount.Load();
			}

			// Decrements the task counter and returns the result of Atomic32::DecFetch().
			int32 Dec()
			{
				return inProgressTaskCount.DecFetch();
			}

			// Increments the task counter and returns the result of Atomic32::IncFetch().
			int32 Inc()
			{
				return inProgressTaskCount.IncFetch();
			}

			// Adds `sum` to the task counter and returns the result of Atomic32::AddFetch().
			int32 Add(int sum)
			{
				return inProgressTaskCount.AddFetch(sum);
			}

			// Exposes the address of the task counter (same pointer type as WaitContext::waitCounter).
			Atomic32<int32>* GetWaitCounter()
			{
				return &inProgressTaskCount;
			}

#if MT_GROUP_DEBUG
			// Debug-only accessors for the "is free" flag.
			void SetDebugIsFree(bool _debugIsFree)
			{
				debugIsFree = _debugIsFree;
			}

			bool GetDebugIsFree() const
			{
				return debugIsFree;
			}
#endif
		};


		// Plain data bundle describing a wait.
		// NOTE(review): presumably passed to SchedulerFiberWait() through its userData pointer - confirm
		struct WaitContext
		{
			// Counter to wait on
			Atomic32<int32>* waitCounter;
			// Thread context associated with the wait
			internal::ThreadContext* threadContext;
			// Wait time in milliseconds
			uint32 waitTimeMs;
			// Exit code of the wait. NOTE(review): meaning of the values is not visible here
			uint32 exitCode;
		};


		// Thread index for new task
		Atomic32<int32> roundRobinThreadIndex;

		// Started threads count
		Atomic32<int32> startedThreadsCount;

		// Threads created by task manager
		Atomic32<int32> threadsCount;
		// One context slot per possible thread (MT_MAX_THREAD_COUNT slots)
		internal::ThreadContext threadContext[MT_MAX_THREAD_COUNT];

		// Task statistics for all groups together
		TaskGroupDescription allGroups;

		// Groups pool
		LockFreeQueueMPMC<TaskGroup, TaskGroup::MT_MAX_GROUPS_COUNT * 2> availableGroups;

		// Task statistics per group (TaskGroup::MT_MAX_GROUPS_COUNT slots); see GetGroupDesc()
		TaskGroupDescription groupStats[TaskGroup::MT_MAX_GROUPS_COUNT];

		// Fibers context
		FiberContext standartFiberContexts[MT_MAX_STANDART_FIBERS_COUNT];
		FiberContext extendedFiberContexts[MT_MAX_EXTENDED_FIBERS_COUNT];

		// Fibers pool
		LockFreeQueueMPMC<FiberContext*, MT_MAX_STANDART_FIBERS_COUNT * 2> standartFibersAvailable;
		LockFreeQueueMPMC<FiberContext*, MT_MAX_EXTENDED_FIBERS_COUNT * 2> extendedFibersAvailable;

#ifdef MT_INSTRUMENTED_BUILD
		// Listener returned by GetProfilerEventListener() (instrumented builds only)
		IProfilerEventListener * profilerEventListener;
#endif

		// NOTE(review): presumably derived from the stealMode constructor argument - confirm
		bool taskStealingDisabled;

		// Internal helpers; their definitions are not part of this class body.
		FiberContext* RequestFiberContext(internal::GroupedTask& task);
		void ReleaseFiberContext(FiberContext*&& fiberExecutionContext);
		void RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState);
		TaskGroupDescription & GetGroupDesc(TaskGroup group);

		// Static entry points and scheduling steps; the void* userData ones receive an opaque pointer.
		static void WorkerThreadMain( void* userData );
		static void SchedulerFiberMain( void* userData );
		static void SchedulerFiberWait( void* userData );
		static bool SchedulerFiberStep( internal::ThreadContext& context, bool disableTaskStealing);
		static void SchedulerFiberProcessTask( internal::ThreadContext& context, internal::GroupedTask& task );
		static void FiberMain( void* userData );
		static bool TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task);

		static FiberContext* ExecuteTask (internal::ThreadContext& threadContext, FiberContext* fiberContext);

	public:

		/// \brief Initializes a new instance of the TaskScheduler class.
		/// \param workerThreadsCount Worker threads count. The required number of threads is determined automatically if workerThreadsCount is set to 0
		/// \param workerParameters Optional worker thread parameters (see WorkerThreadParams); defaults to nullptr. NOTE(review): presumably one entry per worker thread - confirm
		/// \param listener Profiler event listener (instrumented builds only); defaults to nullptr
		/// \param stealMode Task stealing mode; defaults to TaskStealingMode::ENABLED
#ifdef MT_INSTRUMENTED_BUILD
		TaskScheduler(uint32 workerThreadsCount = 0, WorkerThreadParams* workerParameters = nullptr, IProfilerEventListener* listener = nullptr, TaskStealingMode::Type stealMode = TaskStealingMode::ENABLED);
#else
		TaskScheduler(uint32 workerThreadsCount = 0, WorkerThreadParams* workerParameters = nullptr, TaskStealingMode::Type stealMode = TaskStealingMode::ENABLED);
#endif


		~TaskScheduler();

		/// \brief Runs an array of tasks of type TTask asynchronously as part of the given group.
		/// \param group Task group the tasks are assigned to
		/// \param taskArray Pointer to the first task
		/// \param taskCount Number of tasks in taskArray
		template<class TTask>
		void RunAsync(TaskGroup group, const TTask* taskArray, uint32 taskCount);

		/// \brief Same as above, but takes an array of task handles instead of task objects.
		void RunAsync(TaskGroup group, const TaskHandle* taskHandleArray, uint32 taskHandleCount);

		/// \brief Waits until there are no more tasks in the specified group.
		/// \return true - if there are no more tasks in the specified group. false - if the timeout in milliseconds has been reached and the group still has some tasks.
		bool WaitGroup(TaskGroup group, uint32 milliseconds);

		/// \brief NOTE(review): presumably the WaitGroup() counterpart for all tasks regardless of group - confirm
		bool WaitAll(uint32 milliseconds);

		/// \brief Group management. NOTE(review): presumably backed by the availableGroups pool - confirm
		TaskGroup CreateGroup();
		void ReleaseGroup(TaskGroup group);

		/// \brief NOTE(review): presumably the number of worker threads (threadsCount) - confirm
		int32 GetWorkersCount() const;

		/// \brief NOTE(review): meaning of minWorkersCount is not visible in this header - see the implementation
		bool IsTaskStealingDisabled(uint32 minWorkersCount = 1) const;

		/// \brief NOTE(review): presumably tells whether the calling thread is one of the scheduler's workers - confirm
		bool IsWorkerThread() const;

#ifdef MT_INSTRUMENTED_BUILD

		/// \brief Returns the stored profiler event listener pointer.
		inline IProfilerEventListener* GetProfilerEventListener()
		{
			return profilerEventListener;
		}

		/// \brief Profiler notifications (instrumented builds only).
		void NotifyFibersCreated(uint32 fibersCount);
		void NotifyThreadsCreated(uint32 threadsCount);


#endif
	};
388 }
389 
390 #include "MTScheduler.inl"
391 #include "MTFiberContext.inl"
392