1 // The MIT License (MIT)
2 //
3 // Copyright (c) 2015 Sergey Makeev, Vadim Slyusarev
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included in
13 // all copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 // THE SOFTWARE.
22
#include "Tests.h"

#include <vector>

#include <UnitTest++.h>
#include <MTScheduler.h>
#include <MTStaticVector.h>
27
SUITE(SimpleTests)28 SUITE(SimpleTests)
29 {
30 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
31 struct SimpleTask
32 {
33 MT_DECLARE_TASK(SimpleTask, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);
34
35 static const int sourceData = 0xFF33FF;
36 int resultData;
37
38 SimpleTask() : resultData(0) {}
39
40 void Do(MT::FiberContext&)
41 {
42 resultData = sourceData;
43 }
44
45 int GetSourceData()
46 {
47 return sourceData;
48 }
49 };
50
51 // Checks one simple task
52 TEST(RunOneSimpleTask)
53 {
54 MT::TaskScheduler scheduler;
55
56 SimpleTask task;
57 scheduler.RunAsync(MT::TaskGroup::Default(), &task, 1);
58
59 CHECK(scheduler.WaitAll(1000));
60 CHECK_EQUAL(task.GetSourceData(), task.resultData);
61 }
62 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
63
64 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
65 struct ALotOfTasks
66 {
67 MT_DECLARE_TASK(ALotOfTasks, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);
68
69 MT::Atomic32<int32>* counter;
70
71 void Do(MT::FiberContext&)
72 {
73 counter->IncFetch();
74 MT::SpinSleepMilliSeconds(1);
75 }
76 };
77
78 // Checks one simple task
79 TEST(ALotOfTasks)
80 {
81 MT::TaskScheduler scheduler;
82
83 MT::Atomic32<int32> counter;
84
85 static const int TASK_COUNT = 1000;
86
87 ALotOfTasks tasks[TASK_COUNT];
88
89 for (size_t i = 0; i < MT_ARRAY_SIZE(tasks); ++i)
90 tasks[i].counter = &counter;
91
92 scheduler.RunAsync(MT::TaskGroup::Default(), &tasks[0], MT_ARRAY_SIZE(tasks));
93
94 int timeout = (TASK_COUNT / scheduler.GetWorkersCount()) * 2000;
95
96 CHECK(scheduler.WaitGroup(MT::TaskGroup::Default(), timeout));
97 CHECK_EQUAL(TASK_COUNT, counter.Load());
98 }
99
100
101 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
102
103
104 struct WorkerThreadState
105 {
106 uint32 counterPhase0;
107 uint32 counterPhase1;
108
109 WorkerThreadState()
110 {
111 Reset();
112 }
113
114 void Reset()
115 {
116 counterPhase0 = 0;
117 counterPhase1 = 0;
118 }
119 };
120
121
122 WorkerThreadState workerStates[64];
123
124 uint32 TASK_COUNT_PER_WORKER = 0;
125
126 MT::Atomic32<uint32> finishedTaskCount;
127
128 struct YieldTask
129 {
130 MT::Atomic32<uint32> counter;
131
132 MT_DECLARE_TASK(YieldTask, MT::StackRequirements::STANDARD, MT::TaskPriority::NORMAL, MT::Color::Blue);
133
134 YieldTask()
135 {
136 counter.Store(0);
137 }
138
139
140 volatile WorkerThreadState* GetWorkerState( volatile uint32 workerIndex) volatile
141 {
142 MT_ASSERT(workerIndex < MT_ARRAY_SIZE(workerStates), "Invalid worker index");
143 volatile WorkerThreadState& state = workerStates[workerIndex];
144 return &state;
145 }
146
147 void Do(MT::FiberContext& context)
148 {
149 volatile WorkerThreadState* state0 = GetWorkerState( context.GetThreadContext()->workerIndex );
150
151 // phase 0
152 CHECK_EQUAL((uint32)1, counter.IncFetch());
153 state0->counterPhase0++;
154 context.Yield();
155
156 // worker index can be changed after yield, get actual index
157 volatile WorkerThreadState* state1 = GetWorkerState( context.GetThreadContext()->workerIndex );
158
159 //I check that all the tasks (on this worker) have passed phase0 before executing phase1
160 CHECK_EQUAL(TASK_COUNT_PER_WORKER, state1->counterPhase0);
161
162 // phase 1
163 CHECK_EQUAL((uint32)2, counter.IncFetch());
164 state1->counterPhase1++;
165
166 finishedTaskCount.IncFetch();
167 }
168 };
169
170
171 TEST(YieldTasks)
172 {
173 // Disable task stealing (for testing purposes only)
174 #ifdef MT_INSTRUMENTED_BUILD
175 MT::TaskScheduler scheduler(0, nullptr, nullptr, MT::TaskStealingMode::DISABLED);
176 #else
177 MT::TaskScheduler scheduler(0, nullptr, MT::TaskStealingMode::DISABLED);
178 #endif
179
180 finishedTaskCount.Store(0);
181
182 int32 workersCount = scheduler.GetWorkersCount();
183 TASK_COUNT_PER_WORKER = workersCount * 4;
184 int32 taskCount = workersCount * TASK_COUNT_PER_WORKER;
185
186 MT::HardwareFullMemoryBarrier();
187
188 MT::StaticVector<YieldTask, 512> tasks;
189 for(int32 i = 0; i < taskCount; i++)
190 {
191 tasks.PushBack(YieldTask());
192 }
193
194 for(int32 i = 0; i < workersCount; i++)
195 {
196 WorkerThreadState& state = workerStates[i];
197 state.Reset();
198 }
199
200
201 scheduler.RunAsync(MT::TaskGroup::Default(), tasks.Begin(), (uint32)tasks.Size());
202
203 CHECK(scheduler.WaitGroup(MT::TaskGroup::Default(), 10000));
204
205 for(int32 i = 0; i < workersCount; i++)
206 {
207 WorkerThreadState& state = workerStates[i];
208
209 CHECK_EQUAL(TASK_COUNT_PER_WORKER, state.counterPhase0);
210 CHECK_EQUAL(TASK_COUNT_PER_WORKER, state.counterPhase1);
211 }
212
213 CHECK_EQUAL(taskCount, (int32)finishedTaskCount.Load());
214
215 printf("Yield test: %d tasks finished, used %d workers\n", taskCount, workersCount);
216
217 }
218
219
220
221 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
222 }
223