1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #pragma once
24 
25 #include <MTColorTable.h>
26 #include <MTTools.h>
27 #include <MTPlatform.h>
28 #include <MTConcurrentQueueLIFO.h>
29 #include <MTStackArray.h>
30 #include <MTArrayView.h>
31 #include <MTThreadContext.h>
32 #include <MTFiberContext.h>
33 #include <MTAllocator.h>
34 #include <MTTaskPool.h>
35 #include <MTStackRequirements.h>
36 #include <Scopes/MTScopes.h>
37 
38 
namespace MT
{

	/// Compile-time guard used by MT_DECLARE_TASK.
	/// Instantiating CheckType<A, B> fails with a static_assert unless A and B are the same type.
	/// \tparam CLASS_TYPE Type recovered from the enclosing class's `this` pointer.
	/// \tparam MACRO_TYPE Type name that was passed to the MT_DECLARE_TASK macro.
	template<typename CLASS_TYPE, typename MACRO_TYPE>
	struct CheckType
	{
		static_assert(std::is_same<CLASS_TYPE, MACRO_TYPE>::value, "Invalid type in MT_DECLARE_TASK macro. See CheckType template instantiation params to details.");
	};

	/// Helper that lets a macro expansion recover the type of `this`.
	struct TypeChecker
	{
		/// Returns a null pointer of the same type as its argument.
		/// Meant to be used only inside decltype(), e.g. decltype(QueryThisType(this)),
		/// so the argument value is never inspected; it is left unnamed to avoid
		/// unused-parameter warnings.
		template <typename T>
		static T QueryThisType(T)
		{
			return static_cast<T>(nullptr);
		}
	};


	/// Explicitly invokes the destructor of *p without releasing its storage.
	/// \param p Pointer to a constructed object; the memory it points to is not freed here.
	template <typename T>
	inline void CallDtor(T* p)
	{
		// Presumably here to silence MSVC C4100: when T is trivially destructible,
		// p->~T() generates no code and `p` is reported as unreferenced.
		MT_UNUSED(p);
		p->~T();
	}

}
66 
#if defined(_MSC_VER)

// Visual Studio compile time check.
// Expands to a member function whose body declares an MT::CheckType<ClassType, TYPE>
// variable. ClassType is deduced from `this`, so the static_assert inside CheckType
// fires when TYPE does not name the class the macro was expanded in.
// Note: no `typename` before TYPE - `typename` must be followed by a qualified name,
// and conforming compilers (including MSVC /permissive-) reject it otherwise.
#define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	void CompileTimeCheckMethod() \
	{ \
		MT::CheckType< typename std::remove_pointer< decltype(MT::TypeChecker::QueryThisType(this)) >::type, TYPE > compileTypeTypesCheck; \
		/* remove unused variable warning */ \
		MT_UNUSED(compileTypeTypesCheck); \
	}

#else


// GCC, Clang and other compilers compile time check.
// Same purpose as the Visual Studio variant above, written with intermediate typedefs.
#define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	void CompileTimeCheckMethod() \
	{ \
		/* query this pointer type */ \
		typedef decltype(MT::TypeChecker::QueryThisType(this)) THIS_PTR_TYPE; \
		/* query class type from this pointer type */ \
		typedef typename std::remove_pointer<THIS_PTR_TYPE>::type CPP_TYPE; \
		/* define macro type */ \
		typedef TYPE MACRO_TYPE; \
		/* compile time checking that is same types */ \
		MT::CheckType< CPP_TYPE, MACRO_TYPE > compileTypeTypesCheck; \
		/* remove unused variable warning */ \
		MT_UNUSED(compileTypeTypesCheck); \
	}

#endif
97 
98 
99 
100 
// Declares the static members the scheduler uses for a task type TYPE:
//  - CompileTimeCheckMethod (via MT_COMPILE_TIME_TYPE_CHECK): rejects a TYPE that does
//    not match the enclosing class;
//  - TaskEntryPoint: casts userData back to TYPE and calls its Do(fiberContext);
//  - PoolTaskDestroy: runs ~TYPE() on userData, then marks the MT::PoolElementHeader
//    stored immediately before the object as MT::TaskID::UNUSED;
//  - GetStackRequirements: returns STACK_REQUIREMENTS.
// The last line of the macro intentionally has no trailing '\', so the line that
// follows the definition is not silently absorbed into it.
#define MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS) \
	\
	MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	\
	static void TaskEntryPoint(MT::FiberContext& fiberContext, const void* userData) \
	{ \
		/* userData arrives as const void*; const is cast away explicitly to call Do() */ \
		TYPE * task = static_cast< TYPE * >(const_cast< void * >(userData)); \
		task->Do(fiberContext); \
	} \
	\
	static void PoolTaskDestroy(const void* userData) \
	{ \
		/* userData arrives as const void*; const is cast away explicitly to destroy the task */ \
		TYPE * task = static_cast< TYPE * >(const_cast< void * >(userData)); \
		MT::CallDtor( task ); \
		/* Find task pool header: it is stored right before the task object */ \
		char * taskBytes = static_cast< char * >(const_cast< void * >(userData)); \
		MT::PoolElementHeader * poolHeader = reinterpret_cast< MT::PoolElementHeader * >(taskBytes - sizeof(MT::PoolElementHeader)); \
		/* Fixup pool header, mark task as unused */ \
		poolHeader->id.Store(MT::TaskID::UNUSED); \
	} \
	\
	static MT::StackRequirements::Type GetStackRequirements() \
	{ \
		return STACK_REQUIREMENTS; \
	}
128 
129 
#ifdef MT_INSTRUMENTED_BUILD
#include <MTProfilerEventListener.h>

// Declares a task type for the scheduler (instrumented build).
// In addition to everything MT_DECLARE_TASK_IMPL declares, adds:
//  - GetDebugID: returns the stringified TYPE name (wrapped in MT_TEXT);
//  - GetDebugColor: returns DEBUG_COLOR.
// The expansion ends with a function body, so no ';' is appended here; a ';' written
// by the user after the macro invocation is still accepted.
#define MT_DECLARE_TASK(TYPE, STACK_REQUIREMENTS, DEBUG_COLOR) \
	static const mt_char* GetDebugID() \
	{ \
		return MT_TEXT( #TYPE ); \
	} \
	\
	static MT::Color::Type GetDebugColor() \
	{ \
		return DEBUG_COLOR; \
	} \
	\
	MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS)


#else

// Declares a task type for the scheduler (non-instrumented build).
// DEBUG_COLOR is accepted so call sites compile in both builds, but it is unused here.
// As above, no ';' is appended after the expansion.
#define MT_DECLARE_TASK(TYPE, STACK_REQUIREMENTS, DEBUG_COLOR) \
	MT_DECLARE_TASK_IMPL(TYPE, STACK_REQUIREMENTS)

#endif
153 
154 
155 
156 
157 
158 
namespace MT
{
	// Maximum number of worker threads; sizes TaskScheduler::threadContext.
	const uint32 MT_MAX_THREAD_COUNT = 64;
	// Scheduler thread stack size in bytes (1048576 = 1 MiB).
	// NOTE(review): presumably the worker thread stack size - confirm in the implementation.
	const uint32 MT_SCHEDULER_STACK_SIZE = 1048576; // 1 MiB

	// Number of pre-allocated standard fibers; sizes TaskScheduler::standartFiberContexts.
	const uint32 MT_MAX_STANDART_FIBERS_COUNT = 256;
	// Standard fiber stack size in bytes (presumably per fiber - confirm where fibers are created).
	const uint32 MT_STANDART_FIBER_STACK_SIZE = 32768; // 32 KiB

	// Number of pre-allocated extended fibers; sizes TaskScheduler::extendedFiberContexts.
	const uint32 MT_MAX_EXTENDED_FIBERS_COUNT = 8;
	// Extended fiber stack size in bytes (presumably per fiber - confirm where fibers are created).
	const uint32 MT_EXTENDED_FIBER_STACK_SIZE = 1048576; // 1 MiB

	namespace internal
	{
		// Forward declaration; befriended by TaskScheduler below.
		struct ThreadContext;
	}

	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
	// Task scheduler
	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
	// Holds the worker thread contexts, the task group bookkeeping and the pools of
	// pre-allocated fiber contexts. Most member functions are defined out of line
	// (not visible in this header).
	class TaskScheduler
	{
		// FiberContext and internal::ThreadContext are granted access to the private state below.
		friend class FiberContext;
		friend struct internal::ThreadContext;



		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		// Task group description
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		// Application can assign a task group to tasks and later wait until the group is finished.
		// This class keeps the per-group state: an in-progress task counter, an event to wait on,
		// and a queue of fibers waiting for the group.
		class TaskGroupDescription
		{
			// Number of tasks of this group currently in progress.
			AtomicInt32 inProgressTaskCount;
			// Event waited on by Wait(); created as a manual-reset event in the constructor.
			Event allDoneEvent;

			//Tasks awaiting group through FiberContext::WaitGroupAndYield call
			ConcurrentQueueLIFO<FiberContext*> waitTasksQueue;

			// Debug-only flag; true after construction, updated via SetDebugIsFree().
			// NOTE(review): presumably marks whether this group slot is currently unallocated - confirm.
			bool debugIsFree;

		public:

			MT_NOCOPYABLE(TaskGroupDescription);

			// Starts with zero tasks in progress and creates the manual-reset event.
			// NOTE(review): the `true` argument is presumably "initially signaled", so waiting
			// on a group with no tasks returns immediately - confirm against Event::Create.
			TaskGroupDescription()
			{
				inProgressTaskCount.Store(0);
				allDoneEvent.Create( EventReset::MANUAL, true );
				debugIsFree = true;
			}

			// Returns the current number of in-progress tasks.
			int GetTaskCount() const
			{
				return inProgressTaskCount.Load();
			}

			// Returns the queue of fibers waiting for this group.
			ConcurrentQueueLIFO<FiberContext*> & GetWaitQueue()
			{
				return waitTasksQueue;
			}

			// Decrements the in-progress counter via DecFetch (presumably returns the updated value).
			int Dec()
			{
				return inProgressTaskCount.DecFetch();
			}

			// Increments the in-progress counter via IncFetch (presumably returns the updated value).
			int Inc()
			{
				return inProgressTaskCount.IncFetch();
			}

			// Adds `sum` to the in-progress counter via AddFetch (presumably returns the updated value).
			int Add(int sum)
			{
				return inProgressTaskCount.AddFetch(sum);
			}

			// Signals allDoneEvent.
			void Signal()
			{
				allDoneEvent.Signal();
			}

			// Resets allDoneEvent.
			void Reset()
			{
				allDoneEvent.Reset();
			}

			// Waits on allDoneEvent for up to `milliseconds` and forwards Event::Wait's result
			// (presumably true when the event was signaled before the timeout).
			bool Wait(uint32 milliseconds)
			{
				return allDoneEvent.Wait(milliseconds);
			}

			// Sets the debug "is free" flag.
			void SetDebugIsFree(bool _debugIsFree)
			{
				debugIsFree = _debugIsFree;
			}

			// Returns the debug "is free" flag.
			bool GetDebugIsFree() const
			{
				return debugIsFree;
			}
		};


		// Thread index for new task (presumably a round-robin cursor over worker threads).
		AtomicInt32 roundRobinThreadIndex;

		// Started threads count
		AtomicInt32 startedThreadsCount;

		// Threads created by the task scheduler, and their per-thread contexts
		// (capacity MT_MAX_THREAD_COUNT).
		AtomicInt32 threadsCount;
		internal::ThreadContext threadContext[MT_MAX_THREAD_COUNT];

		// Task statistics aggregated over all groups
		TaskGroupDescription allGroups;

		// Groups pool (presumably handed out by CreateGroup and returned by ReleaseGroup)
		ConcurrentQueueLIFO<TaskGroup> availableGroups;

		// Per-group descriptions, one slot per possible TaskGroup (TaskGroup::MT_MAX_GROUPS_COUNT);
		// presumably looked up through GetGroupDesc()
		TaskGroupDescription groupStats[TaskGroup::MT_MAX_GROUPS_COUNT];

		// Pre-allocated fiber contexts: standard and extended
		FiberContext standartFiberContexts[MT_MAX_STANDART_FIBERS_COUNT];
		FiberContext extendedFiberContexts[MT_MAX_EXTENDED_FIBERS_COUNT];

		// Fiber pools (presumably the currently unused contexts of each kind)
		ConcurrentQueueLIFO<FiberContext*> standartFibersAvailable;
		ConcurrentQueueLIFO<FiberContext*> extendedFibersAvailable;

		// Returns the fiber pool matching the given stack requirements
		// (presumably standartFibersAvailable or extendedFibersAvailable).
		ConcurrentQueueLIFO<FiberContext*>* GetFibersStorage(MT::StackRequirements::Type stackRequirements);

#ifdef MT_INSTRUMENTED_BUILD
		// Profiler listener passed to the constructor (instrumented build only).
		IProfilerEventListener * profilerEventListener;
#endif

		// Fiber pool management and task execution internals; defined out of line.
		FiberContext* RequestFiberContext(internal::GroupedTask& task);
		void ReleaseFiberContext(FiberContext* fiberExecutionContext);
		void RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState);
		// Returns the description associated with `group`.
		TaskGroupDescription & GetGroupDesc(TaskGroup group);

		// Thread and fiber entry points; `userData` is an opaque context pointer.
		static void ThreadMain( void* userData );
		static void FiberMain( void* userData );
		// Presumably tries to take a task queued on another worker thread; returns whether one was obtained.
		static bool TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount);

		static FiberContext* ExecuteTask (internal::ThreadContext& threadContext, FiberContext* fiberContext);

	public:

		/// \brief Initializes a new instance of the TaskScheduler class.
		/// \param workerThreadsCount Worker threads count. Automatically determines the required number of threads if workerThreadsCount set to 0
		/// \param listener (Instrumented build only) profiler event listener; may be nullptr.
#ifdef MT_INSTRUMENTED_BUILD
		TaskScheduler(uint32 workerThreadsCount = 0, IProfilerEventListener* listener = nullptr);
#else
		TaskScheduler(uint32 workerThreadsCount = 0);
#endif


		~TaskScheduler();

		/// \brief Submits `taskCount` tasks from `taskArray` to run asynchronously as part of `group`.
		/// Template definition is not in this header (presumably in MTScheduler.inl).
		template<class TTask>
		void RunAsync(TaskGroup group, const TTask* taskArray, uint32 taskCount);

		/// \brief Submits `taskHandleCount` task handles from `taskHandleArray` to run asynchronously as part of `group`.
		void RunAsync(TaskGroup group, const TaskHandle* taskHandleArray, uint32 taskHandleCount);


		/// \brief Waits up to `milliseconds` for `group` to finish.
		/// \return Presumably true if the group finished before the timeout - confirm.
		bool WaitGroup(TaskGroup group, uint32 milliseconds);
		/// \brief Waits up to `milliseconds` for all tasks to finish.
		/// \return Presumably true if everything finished before the timeout - confirm.
		bool WaitAll(uint32 milliseconds);

		/// \brief Allocates a task group.
		TaskGroup CreateGroup();
		/// \brief Returns a task group previously obtained from CreateGroup().
		void ReleaseGroup(TaskGroup group);

		/// \brief Presumably reports whether no tasks are pending - confirm.
		bool IsEmpty();

		/// \brief Returns the number of worker threads.
		int32 GetWorkersCount() const;

		/// \brief Presumably reports whether the calling thread is one of this scheduler's workers.
		bool IsWorkerThread() const;

#ifdef MT_INSTRUMENTED_BUILD

		/// \brief Returns the profiler event listener (instrumented build only).
		inline IProfilerEventListener* GetProfilerEventListener()
		{
			return profilerEventListener;
		}

#endif
	};
}
347 
348 #include "MTScheduler.inl"
349 #include "MTFiberContext.inl"
350