1 // The MIT License (MIT)
2 //
3 // 	Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // 	Permission is hereby granted, free of charge, to any person obtaining a copy
6 // 	of this software and associated documentation files (the "Software"), to deal
7 // 	in the Software without restriction, including without limitation the rights
8 // 	to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // 	copies of the Software, and to permit persons to whom the Software is
10 // 	furnished to do so, subject to the following conditions:
11 //
12 //  The above copyright notice and this permission notice shall be included in
13 // 	all copies or substantial portions of the Software.
14 //
15 // 	THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // 	IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // 	FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // 	AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // 	LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // 	OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // 	THE SOFTWARE.
22 
23 #include "Tests.h"
24 #include <UnitTest++.h>
25 #include <MTScheduler.h>
26 #include <MTStaticVector.h>
27 
28 SUITE(SimpleTests)
29 {
30 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
31 struct SimpleTask
32 {
33 	MT_DECLARE_TASK(SimpleTask, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);
34 
35 	static const int sourceData = 0xFF33FF;
36 	int resultData;
37 
38 	SimpleTask() : resultData(0) {}
39 
40 	void Do(MT::FiberContext&)
41 	{
42 		resultData = sourceData;
43 	}
44 
45 	int GetSourceData()
46 	{
47 		return sourceData;
48 	}
49 };
50 
51 // Checks one simple task
52 TEST(RunOneSimpleTask)
53 {
54 	MT::TaskScheduler scheduler;
55 
56 	SimpleTask task;
57 	scheduler.RunAsync(MT::TaskGroup::Default(), &task, 1);
58 
59 	CHECK(scheduler.WaitAll(1000));
60 	CHECK_EQUAL(task.GetSourceData(), task.resultData);
61 }
62 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
63 
64 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
65 struct ALotOfTasks
66 {
67 	MT_DECLARE_TASK(ALotOfTasks, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);
68 
69 	MT::Atomic32<int32>* counter;
70 
71 	void Do(MT::FiberContext&)
72 	{
73 		counter->IncFetch();
74 		MT::Thread::SpinSleepMilliSeconds(1);
75 	}
76 };
77 
78 // Checks one simple task
79 TEST(ALotOfTasks)
80 {
81 	MT::TaskScheduler scheduler;
82 
83 	MT::Atomic32<int32> counter;
84 
85 	static const int TASK_COUNT = 1000;
86 
87 	ALotOfTasks tasks[TASK_COUNT];
88 
89 	for (size_t i = 0; i < MT_ARRAY_SIZE(tasks); ++i)
90 		tasks[i].counter = &counter;
91 
92 	scheduler.RunAsync(MT::TaskGroup::Default(), &tasks[0], MT_ARRAY_SIZE(tasks));
93 
94 	int timeout = (TASK_COUNT / scheduler.GetWorkersCount()) * 2000;
95 
96 	CHECK(scheduler.WaitGroup(MT::TaskGroup::Default(), timeout));
97 	CHECK_EQUAL(TASK_COUNT, counter.Load());
98 }
99 
100 
101 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
102 
103 
104 struct WorkerThreadState
105 {
106 	uint32 counterPhase0;
107 	uint32 counterPhase1;
108 
109 	WorkerThreadState()
110 	{
111 		counterPhase0 = 0;
112 		counterPhase1 = 0;
113 	}
114 };
115 
116 
// Per-worker phase counters written by YieldTask, indexed by the worker index
// the fiber context reports. Fixed at 64 slots.
WorkerThreadState workerStates[64];

// Number of YieldTasks expected to run on each worker; set by the YieldTasks
// test before the tasks are scheduled and read by YieldTask::Do.
uint32 TASK_COUNT_PER_WORKER = 0;

// Count of YieldTasks that completed both phases; reset by the YieldTasks test.
MT::Atomic32<uint32> finishedTaskCount;
122 
// Two-phase task for the YieldTasks test. Phase 0 runs, the fiber yields, and
// phase 1 runs after it is resumed. Progress is recorded per task (counter)
// and per worker (the global workerStates array).
struct YieldTask
{
	// Per-task phase counter: expected to read 1 after phase 0 and 2 after phase 1.
	MT::Atomic32<uint32> counter;

	MT_DECLARE_TASK(YieldTask, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);

	YieldTask()
	{
		counter.Store(0);
	}


	// Returns a pointer to the workerStates slot for workerIndex, asserting the
	// index is within the array. NOTE(review): the volatile qualifiers presumably
	// stop the compiler from caching the counters across Yield(); confirm.
	 volatile WorkerThreadState* GetWorkerState( volatile uint32 workerIndex) volatile
	{
		MT_ASSERT(workerIndex < MT_ARRAY_SIZE(workerStates), "Invalid worker index");
		volatile WorkerThreadState& state = workerStates[workerIndex];
		return &state;
	}

	void Do(MT::FiberContext& context)
	{
		// State of the worker executing phase 0.
		volatile WorkerThreadState* state0 = GetWorkerState( context.GetThreadContext()->workerIndex );

		// phase 0: must be the first phase this task has executed
		CHECK_EQUAL((uint32)1, counter.IncFetch());
		state0->counterPhase0++;
		context.Yield();

		// worker index can be changed after yield, get actual index
		volatile WorkerThreadState* state1 = GetWorkerState( context.GetThreadContext()->workerIndex );

		// Check that all the tasks on this worker have passed phase 0 before any
		// of them executes phase 1. NOTE(review): this relies on Yield() letting
		// the worker's other pending tasks run first and on task stealing being
		// disabled by the test.
		CHECK_EQUAL(TASK_COUNT_PER_WORKER, state1->counterPhase0);

		// phase 1: must follow exactly one phase 0
		CHECK_EQUAL((uint32)2, counter.IncFetch());
		state1->counterPhase1++;

		// Count tasks that completed both phases.
		finishedTaskCount.IncFetch();
	}
};
164 
165 
166 TEST(YieldTasks)
167 {
168 	// Disable task stealing (for testing purposes only)
169 #ifdef MT_INSTRUMENTED_BUILD
170 	MT::TaskScheduler scheduler(0, nullptr, nullptr, MT::TaskStealingMode::DISABLED);
171 #else
172 	MT::TaskScheduler scheduler(0, nullptr, MT::TaskStealingMode::DISABLED);
173 #endif
174 
175 	finishedTaskCount.Store(0);
176 
177 	int32 workersCount = scheduler.GetWorkersCount();
178 	TASK_COUNT_PER_WORKER = workersCount * 4;
179 	int32 taskCount = workersCount * TASK_COUNT_PER_WORKER;
180 
181 	MT::HardwareFullMemoryBarrier();
182 
183 	MT::StaticVector<YieldTask, 512> tasks;
184 	for(int32 i = 0; i < taskCount; i++)
185 	{
186 		tasks.PushBack(YieldTask());
187 	}
188 
189 	scheduler.RunAsync(MT::TaskGroup::Default(), tasks.Begin(), (uint32)tasks.Size());
190 
191 	CHECK(scheduler.WaitGroup(MT::TaskGroup::Default(), 10000));
192 
193 	for(int32 i = 0; i < workersCount; i++)
194 	{
195 		WorkerThreadState& state = workerStates[i];
196 
197 		CHECK_EQUAL(TASK_COUNT_PER_WORKER, state.counterPhase0);
198 		CHECK_EQUAL(TASK_COUNT_PER_WORKER, state.counterPhase1);
199 	}
200 
201 	CHECK_EQUAL(taskCount, (int32)finishedTaskCount.Load());
202 
203 	printf("Yield test: %d tasks finished, used %d workers\n", taskCount, workersCount);
204 
205 }
206 
207 
208 
209 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
210 }
211