1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #pragma once
24 
25 #include <MTColorTable.h>
26 #include <MTTools.h>
27 #include <MTPlatform.h>
28 #include <MTConcurrentQueueLIFO.h>
29 #include <MTStackArray.h>
30 #include <MTArrayView.h>
31 #include <MTThreadContext.h>
32 #include <MTFiberContext.h>
33 #include <MTAllocator.h>
34 #include <MTTaskPool.h>
35 
36 
namespace MT
{

	/// \brief Compile-time guard used by MT_DECLARE_TASK.
	///
	/// Instantiating CheckType<CLASS_TYPE, MACRO_TYPE> fails with a static_assert unless both parameters name the same
	/// type. MT_COMPILE_TIME_TYPE_CHECK passes the class deduced from `this` and the type written in the macro, so a
	/// task declared with the wrong class name does not compile.
	template<typename CLASS_TYPE, typename MACRO_TYPE>
	struct CheckType
	{
		static_assert(std::is_same<CLASS_TYPE, MACRO_TYPE>::value, "Invalid type in MT_DECLARE_TASK macro. See CheckType template instantiation params to details.");
	};

	/// \brief Helper that lets a macro recover the type of `this` inside a member function.
	struct TypeChecker
	{
		/// \brief Returns a null pointer of the argument's type.
		///
		/// Intended for unevaluated contexts such as decltype(MT::TypeChecker::QueryThisType(this)); the argument value
		/// is never read, so the parameter is left unnamed to avoid unused-parameter warnings.
		template <typename T>
		static T QueryThisType(T /*thisPtr*/)
		{
			return static_cast<T>(nullptr);
		}
	};

	/// \brief Invokes the destructor of *p in place without releasing its storage.
	///
	/// \param p Object to destroy; must not be null. The memory it occupies stays owned by the caller.
	template <typename T>
	inline void CallDtor(T * p)
	{
		// For a trivially destructible T the call below generates no code, and some compilers (notably MSVC) then
		// report `p` as unreferenced. The void cast silences that on every compiler without a platform #if.
		(void)p;
		p->~T();
	}

}
66 
// Marks a value as intentionally unused so compilers do not warn about it.
// Defined for every compiler because MT_COMPILE_TIME_TYPE_CHECK relies on it; guarded in case another header already
// provides it.
#ifndef MT_UNUSED
#define MT_UNUSED(x) (void)(x)
#endif

// MT_COMPILE_TIME_TYPE_CHECK(TYPE) expands to a member function CompileTimeCheckMethod() that deduces the enclosing
// class from the type of `this` and instantiates MT::CheckType with it and TYPE. The instantiation fails
// (static_assert) when TYPE is not the enclosing class.
//
// TYPE is used as a plain type-id, never as `typename TYPE`: a typename-specifier requires a qualified name, and
// clang-cl (which also defines _MSC_VER) rejects the unqualified form.
#if defined(_MSC_VER)

// Visual Studio compile time check
#define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	void CompileTimeCheckMethod() \
	{ \
		MT::CheckType< typename std::remove_pointer< decltype(MT::TypeChecker::QueryThisType(this)) >::type, TYPE > compileTypeTypesCheck; \
		MT_UNUSED(compileTypeTypesCheck); \
	}

#else

// GCC, Clang and other compilers compile time check
#define MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	void CompileTimeCheckMethod() \
	{ \
		/* query this pointer type */ \
		typedef decltype(MT::TypeChecker::QueryThisType(this)) THIS_PTR_TYPE; \
		/* query class type from this pointer type */ \
		typedef typename std::remove_pointer<THIS_PTR_TYPE>::type CPP_TYPE; \
		/* define macro type */ \
		typedef TYPE MACRO_TYPE; \
		/* compile time checking that is same types */ \
		MT::CheckType< CPP_TYPE, MACRO_TYPE > compileTypeTypesCheck; \
		/* remove unused variable warning */ \
		MT_UNUSED(compileTypeTypesCheck); \
	}

#endif
98 
99 
100 
101 
// MT_DECLARE_TASK_IMPL(TYPE) adds the static entry points used for a task type (expanded by MT_DECLARE_TASK):
//  - TaskEntryPoint: casts userData back to TYPE* and calls its Do(fiberContext) member.
//  - PoolTaskDestroy: runs the task's destructor in place, then finds the MT::PoolElementHeader located directly in
//    front of the task and stores MT::TaskID::UNUSED into its id, marking the task as unused.
// NOTE(review): PoolTaskDestroy assumes the task starts exactly sizeof(MT::PoolElementHeader) bytes after its header
// (no padding between them) — confirm against the task pool layout.
// The last line intentionally has no trailing backslash so the macro cannot swallow the following source line.
#define MT_DECLARE_TASK_IMPL(TYPE) \
	\
	MT_COMPILE_TIME_TYPE_CHECK(TYPE) \
	\
	static void TaskEntryPoint(MT::FiberContext& fiberContext, void* userData) \
	{ \
		TYPE * task = static_cast< TYPE *>(userData); \
		task->Do(fiberContext); \
	} \
	\
	static void PoolTaskDestroy(void* userData) \
	{ \
		TYPE * task = static_cast< TYPE *>(userData); \
		MT::CallDtor( task ); \
		/* Find task pool header, stored immediately before the task object */ \
		MT::PoolElementHeader * poolHeader = reinterpret_cast< MT::PoolElementHeader * >(static_cast< char * >(userData) - sizeof(MT::PoolElementHeader)); \
		/* Fixup pool header, mark task as unused */ \
		poolHeader->id.Store(MT::TaskID::UNUSED); \
	}
121 
122 
123 
#ifdef MT_INSTRUMENTED_BUILD
#include <MTProfilerEventListener.h>

// MT_DECLARE_TASK(TYPE, DEBUG_COLOR), instrumented build: besides the MT_DECLARE_TASK_IMPL entry points it adds
// GetDebugID() (the stringized TYPE name passed through MT_TEXT) and GetDebugColor() (returns DEBUG_COLOR).
// The expansion ends with a function body, so no ';' is appended; a ';' written by the user after the macro still
// compiles and no longer produces a doubled ';;' inside the class.
#define MT_DECLARE_TASK(TYPE, DEBUG_COLOR) \
	static const mt_char* GetDebugID() \
	{ \
		return MT_TEXT( #TYPE ); \
	} \
	\
	static MT::Color::Type GetDebugColor() \
	{ \
		return DEBUG_COLOR; \
	} \
	\
	MT_DECLARE_TASK_IMPL(TYPE)


#else

// MT_DECLARE_TASK(TYPE, DEBUG_COLOR), regular build: DEBUG_COLOR is accepted for source compatibility with the
// instrumented build and ignored.
#define MT_DECLARE_TASK(TYPE, DEBUG_COLOR) \
	MT_DECLARE_TASK_IMPL(TYPE)

#endif
147 
148 
149 
150 
151 
152 
153 namespace MT
154 {
155 	const uint32 MT_MAX_THREAD_COUNT = 64;
156 	const uint32 MT_MAX_FIBERS_COUNT = 256;
157 	const uint32 MT_SCHEDULER_STACK_SIZE = 1048576;
158 	const uint32 MT_FIBER_STACK_SIZE = 65536;
159 
160 	namespace internal
161 	{
162 		struct ThreadContext;
163 	}
164 
165 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
166 	// Task scheduler
167 	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
168 	class TaskScheduler
169 	{
170 		friend class FiberContext;
171 		friend struct internal::ThreadContext;
172 
173 
174 
175 		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
176 		// Task group description
177 		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
178 		// Application can assign task group to task and later wait until group was finished.
179 		class TaskGroupDescription
180 		{
181 			AtomicInt32 inProgressTaskCount;
182 			Event allDoneEvent;
183 
184 			//Tasks awaiting group through FiberContext::WaitGroupAndYield call
185 			ConcurrentQueueLIFO<FiberContext*> waitTasksQueue;
186 
187 		public:
188 
189 			bool debugIsFree;
190 
191 
192 		private:
193 
194 			TaskGroupDescription(TaskGroupDescription& ) {}
195 			void operator=(const TaskGroupDescription&) {}
196 
197 		public:
198 
199 			TaskGroupDescription()
200 			{
201 				inProgressTaskCount.Store(0);
202 				allDoneEvent.Create( EventReset::MANUAL, true );
203 				debugIsFree = true;
204 			}
205 
206 			int GetTaskCount() const
207 			{
208 				return inProgressTaskCount.Load();
209 			}
210 
211 			ConcurrentQueueLIFO<FiberContext*> & GetWaitQueue()
212 			{
213 				return waitTasksQueue;
214 			}
215 
216 			int Dec()
217 			{
218 				return inProgressTaskCount.DecFetch();
219 			}
220 
221 			int Inc()
222 			{
223 				return inProgressTaskCount.IncFetch();
224 			}
225 
226 			int Add(int sum)
227 			{
228 				return inProgressTaskCount.AddFetch(sum);
229 			}
230 
231 			void Signal()
232 			{
233 				allDoneEvent.Signal();
234 			}
235 
236 			void Reset()
237 			{
238 				allDoneEvent.Reset();
239 			}
240 
241 			bool Wait(uint32 milliseconds)
242 			{
243 				return allDoneEvent.Wait(milliseconds);
244 			}
245 		};
246 
247 
248 		// Thread index for new task
249 		AtomicInt32 roundRobinThreadIndex;
250 
251 		// Started threads count
252 		AtomicInt32 startedThreadsCount;
253 
254 		// Threads created by task manager
255 		volatile uint32 threadsCount;
256 		internal::ThreadContext threadContext[MT_MAX_THREAD_COUNT];
257 
258 		// All groups task statistic
259 		TaskGroupDescription allGroups;
260 
261 		// Groups pool
262 		ConcurrentQueueLIFO<TaskGroup> availableGroups;
263 
264 		//
265 		TaskGroupDescription groupStats[TaskGroup::MT_MAX_GROUPS_COUNT];
266 
267 		// Fibers pool
268 		ConcurrentQueueLIFO<FiberContext*> availableFibers;
269 
270 		// Fibers context
271 		FiberContext fiberContext[MT_MAX_FIBERS_COUNT];
272 
273 #ifdef MT_INSTRUMENTED_BUILD
274 		IProfilerEventListener * profilerEventListener;
275 #endif
276 
277 		FiberContext* RequestFiberContext(internal::GroupedTask& task);
278 		void ReleaseFiberContext(FiberContext* fiberExecutionContext);
279 		void RunTasksImpl(ArrayView<internal::TaskBucket>& buckets, FiberContext * parentFiber, bool restoredFromAwaitState);
280 		TaskGroupDescription & GetGroupDesc(TaskGroup group);
281 
282 		static void ThreadMain( void* userData );
283 		static void FiberMain( void* userData );
284 		static bool TryStealTask(internal::ThreadContext& threadContext, internal::GroupedTask & task, uint32 workersCount);
285 
286 		static FiberContext* ExecuteTask (internal::ThreadContext& threadContext, FiberContext* fiberContext);
287 
288 	public:
289 
290 		/// \brief Initializes a new instance of the TaskScheduler class.
291 		/// \param workerThreadsCount Worker threads count. Automatically determines the required number of threads if workerThreadsCount set to 0
292 #ifdef MT_INSTRUMENTED_BUILD
293 		TaskScheduler(uint32 workerThreadsCount = 0, IProfilerEventListener* listener = nullptr);
294 #else
295 		TaskScheduler(uint32 workerThreadsCount = 0);
296 #endif
297 
298 
299 		~TaskScheduler();
300 
301 		template<class TTask>
302 		void RunAsync(TaskGroup group, TTask* taskArray, uint32 taskCount);
303 
304 		void RunAsync(TaskGroup group, TaskHandle* taskHandleArray, uint32 taskHandleCount);
305 
306 
307 		bool WaitGroup(TaskGroup group, uint32 milliseconds);
308 		bool WaitAll(uint32 milliseconds);
309 
310 		TaskGroup CreateGroup();
311 		void ReleaseGroup(TaskGroup group);
312 
313 		bool IsEmpty();
314 
315 		uint32 GetWorkerCount() const;
316 
317 		bool IsWorkerThread() const;
318 
319 #ifdef MT_INSTRUMENTED_BUILD
320 
321 		inline IProfilerEventListener* GetProfilerEventListener()
322 		{
323 			return profilerEventListener;
324 		}
325 
326 #endif
327 	};
328 }
329 
330 #include "MTScheduler.inl"
331 #include "MTFiberContext.inl"
332