1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #include "Tests.h"
24 #include <UnitTest++.h>
25 #include <MTScheduler.h>
26 #include <MTStaticVector.h>
27 
28 SUITE(SimpleTests)
29 {
30 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
31 struct SimpleTask
32 {
33 	MT_DECLARE_TASK(SimpleTask, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);
34 
35 	static const int sourceData = 0xFF33FF;
36 	int resultData;
37 
38 	SimpleTask() : resultData(0) {}
39 
40 	void Do(MT::FiberContext&)
41 	{
42 		resultData = sourceData;
43 	}
44 
45 	int GetSourceData()
46 	{
47 		return sourceData;
48 	}
49 };
50 
51 // Checks one simple task
52 TEST(RunOneSimpleTask)
53 {
54 	MT::TaskScheduler scheduler;
55 
56 	SimpleTask task;
57 	scheduler.RunAsync(MT::TaskGroup::Default(), &task, 1);
58 
59 	CHECK(scheduler.WaitAll(1000));
60 	CHECK_EQUAL(task.GetSourceData(), task.resultData);
61 }
62 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
63 
64 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
65 struct ALotOfTasks
66 {
67 	MT_DECLARE_TASK(ALotOfTasks, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);
68 
69 	MT::Atomic32<int32>* counter;
70 
71 	void Do(MT::FiberContext&)
72 	{
73 		counter->IncFetch();
74 		MT::SpinSleepMilliSeconds(1);
75 	}
76 };
77 
78 // Checks one simple task
79 TEST(ALotOfTasks)
80 {
81 	MT::TaskScheduler scheduler;
82 
83 	MT::Atomic32<int32> counter;
84 
85 	static const int TASK_COUNT = 1000;
86 
87 	ALotOfTasks tasks[TASK_COUNT];
88 
89 	for (size_t i = 0; i < MT_ARRAY_SIZE(tasks); ++i)
90 		tasks[i].counter = &counter;
91 
92 	scheduler.RunAsync(MT::TaskGroup::Default(), &tasks[0], MT_ARRAY_SIZE(tasks));
93 
94 	int timeout = (TASK_COUNT / scheduler.GetWorkersCount()) * 2000;
95 
96 	CHECK(scheduler.WaitGroup(MT::TaskGroup::Default(), timeout));
97 	CHECK_EQUAL(TASK_COUNT, counter.Load());
98 }
99 
100 
101 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
102 
103 
104 struct WorkerThreadState
105 {
106 	uint32 counterPhase0;
107 	uint32 counterPhase1;
108 
109 	WorkerThreadState()
110 	{
111 		Reset();
112 	}
113 
114 	void Reset()
115 	{
116 		counterPhase0 = 0;
117 		counterPhase1 = 0;
118 	}
119 };
120 
121 
122 WorkerThreadState workerStates[64];
123 
124 uint32 TASK_COUNT_PER_WORKER = 0;
125 
126 MT::Atomic32<uint32> finishedTaskCount;
127 
// Task used by the YieldTasks test. Each instance runs in two phases separated by
// context.Yield():
// - phase 0 increments the current worker's counterPhase0;
// - phase 1 checks that its (current) worker's counterPhase0 has reached TASK_COUNT_PER_WORKER,
//   then increments counterPhase1 and finishedTaskCount.
struct YieldTask
{
	// Per-task phase counter: IncFetch() must return 1 in phase 0 and 2 in phase 1.
	MT::Atomic32<uint32> counter;

	MT_DECLARE_TASK(YieldTask, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);

	// Explicitly zero the phase counter.
	YieldTask()
	{
		counter.Store(0);
	}


	// Returns the workerStates slot for the given worker index; asserts the index is in range.
	 volatile WorkerThreadState* GetWorkerState( volatile uint32 workerIndex) volatile
	{
		MT_ASSERT(workerIndex < MT_ARRAY_SIZE(workerStates), "Invalid worker index");
		volatile WorkerThreadState& state = workerStates[workerIndex];
		return &state;
	}

	// Task body. Verifies that a yielded task resumes only after every task queued on its
	// worker has finished phase 0.
	void Do(MT::FiberContext& context)
	{
		volatile WorkerThreadState* state0 = GetWorkerState( context.GetThreadContext()->workerIndex );

		// phase 0: first entry into this task, so its own counter must go 0 -> 1
		CHECK_EQUAL((uint32)1, counter.IncFetch());
		state0->counterPhase0++;
		// Give the worker to the other tasks; execution continues below when this task is resumed.
		context.Yield();

		// The worker index may differ after the yield, so look the state up again.
		volatile WorkerThreadState* state1 = GetWorkerState( context.GetThreadContext()->workerIndex );

		// Every task on this worker must have passed phase 0 before any of them enters phase 1.
		CHECK_EQUAL(TASK_COUNT_PER_WORKER, state1->counterPhase0);

		// phase 1: second entry, so the counter must go 1 -> 2
		CHECK_EQUAL((uint32)2, counter.IncFetch());
		state1->counterPhase1++;

		finishedTaskCount.IncFetch();
	}
};
169 
170 
171 TEST(YieldTasks)
172 {
173 	// Disable task stealing (for testing purposes only)
174 #ifdef MT_INSTRUMENTED_BUILD
175 	MT::TaskScheduler scheduler(0, nullptr, nullptr, MT::TaskStealingMode::DISABLED);
176 #else
177 	MT::TaskScheduler scheduler(0, nullptr, MT::TaskStealingMode::DISABLED);
178 #endif
179 
180 	finishedTaskCount.Store(0);
181 
182 	int32 workersCount = scheduler.GetWorkersCount();
183 	TASK_COUNT_PER_WORKER = workersCount * 4;
184 	int32 taskCount = workersCount * TASK_COUNT_PER_WORKER;
185 
186 	MT::HardwareFullMemoryBarrier();
187 
188 	MT::StaticVector<YieldTask, 512> tasks;
189 	for(int32 i = 0; i < taskCount; i++)
190 	{
191 		tasks.PushBack(YieldTask());
192 	}
193 
194 	for(int32 i = 0; i < workersCount; i++)
195 	{
196 		WorkerThreadState& state = workerStates[i];
197 		state.Reset();
198 	}
199 
200 
201 	scheduler.RunAsync(MT::TaskGroup::Default(), tasks.Begin(), (uint32)tasks.Size());
202 
203 	CHECK(scheduler.WaitGroup(MT::TaskGroup::Default(), 10000));
204 
205 	for(int32 i = 0; i < workersCount; i++)
206 	{
207 		WorkerThreadState& state = workerStates[i];
208 
209 		CHECK_EQUAL(TASK_COUNT_PER_WORKER, state.counterPhase0);
210 		CHECK_EQUAL(TASK_COUNT_PER_WORKER, state.counterPhase1);
211 	}
212 
213 	CHECK_EQUAL(taskCount, (int32)finishedTaskCount.Load());
214 
215 	printf("Yield test: %d tasks finished, used %d workers\n", taskCount, workersCount);
216 
217 }
218 
219 
220 
221 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
222 }
223