1 /*
2 Copyright (c) 2005-2022 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 //! \file test_resumable_tasks.cpp
18 //! \brief Test for [scheduler.resumable_tasks] specification
19
20 #include "common/test.h"
21 #include "common/utils.h"
22
23 #include "tbb/task.h"
24
25 #if __TBB_RESUMABLE_TASKS
26
27 #include "tbb/global_control.h"
28 #include "tbb/task_arena.h"
29 #include "tbb/parallel_for.h"
30 #include "tbb/task_scheduler_observer.h"
31 #include "tbb/task_group.h"
32
33 #include <algorithm>
34 #include <thread>
35 #include <queue>
36 #include <condition_variable>
37
// Default trip count shared by the parallel loops in this file.
constexpr int N = 10;
39
40 // External activity used in all tests, which resumes suspended execution point
41 class AsyncActivity {
42 public:
AsyncActivity(int num_)43 AsyncActivity(int num_) : m_numAsyncThreads(num_) {
44 for (int i = 0; i < m_numAsyncThreads ; ++i) {
45 m_asyncThreads.push_back( new std::thread(AsyncActivity::asyncLoop, this) );
46 }
47 }
~AsyncActivity()48 ~AsyncActivity() {
49 {
50 std::lock_guard<std::mutex> lock(m_mutex);
51 for (int i = 0; i < m_numAsyncThreads; ++i) {
52 m_tagQueue.push(nullptr);
53 }
54 m_condvar.notify_all();
55 }
56 for (int i = 0; i < m_numAsyncThreads; ++i) {
57 m_asyncThreads[i]->join();
58 delete m_asyncThreads[i];
59 }
60 CHECK(m_tagQueue.empty());
61 }
submit(tbb::task::suspend_point ctx)62 void submit(tbb::task::suspend_point ctx) {
63 std::lock_guard<std::mutex> lock(m_mutex);
64 m_tagQueue.push(ctx);
65 m_condvar.notify_one();
66 }
67
68 private:
asyncLoop(AsyncActivity * async)69 static void asyncLoop(AsyncActivity* async) {
70 tbb::task::suspend_point tag;
71 for (;;) {
72 {
73 std::unique_lock<std::mutex> lock(async->m_mutex);
74 async->m_condvar.wait(lock, [async] {return !async->m_tagQueue.empty(); });
75 tag = async->m_tagQueue.front();
76 async->m_tagQueue.pop();
77 }
78 if (!tag) {
79 break;
80 }
81 tbb::task::resume(tag);
82 };
83 }
84
85 const int m_numAsyncThreads;
86 std::mutex m_mutex;
87 std::condition_variable m_condvar;
88 std::queue<tbb::task::suspend_point> m_tagQueue;
89 std::vector<std::thread*> m_asyncThreads;
90 };
91
92 struct SuspendBody {
SuspendBodySuspendBody93 SuspendBody(AsyncActivity& a_, std::thread::id id) :
94 m_asyncActivity(a_), thread_id(id) {}
operator ()SuspendBody95 void operator()(tbb::task::suspend_point tag) {
96 CHECK(thread_id == std::this_thread::get_id());
97 m_asyncActivity.submit(tag);
98 }
99
100 private:
101 AsyncActivity& m_asyncActivity;
102 std::thread::id thread_id;
103 };
104
105 class InnermostArenaBody {
106 public:
InnermostArenaBody(AsyncActivity & a_)107 InnermostArenaBody(AsyncActivity& a_) : m_asyncActivity(a_) {}
108
operator ()()109 void operator()() {
110 InnermostOuterParFor inner_outer_body(m_asyncActivity);
111 tbb::parallel_for(0, N, inner_outer_body );
112 }
113
114 private:
115 struct InnermostInnerParFor {
InnermostInnerParForInnermostArenaBody::InnermostInnerParFor116 InnermostInnerParFor(AsyncActivity& a_) : m_asyncActivity(a_) {}
operator ()InnermostArenaBody::InnermostInnerParFor117 void operator()(int) const {
118 tbb::task::suspend(SuspendBody(m_asyncActivity, std::this_thread::get_id()));
119 }
120 AsyncActivity& m_asyncActivity;
121 };
122 struct InnermostOuterParFor {
InnermostOuterParForInnermostArenaBody::InnermostOuterParFor123 InnermostOuterParFor(AsyncActivity& a_) : m_asyncActivity(a_) {}
operator ()InnermostArenaBody::InnermostOuterParFor124 void operator()(int) const {
125 tbb::task::suspend(SuspendBody(m_asyncActivity, std::this_thread::get_id()));
126 InnermostInnerParFor inner_inner_body(m_asyncActivity);
127 tbb::parallel_for(0, N, inner_inner_body);
128 }
129 AsyncActivity& m_asyncActivity;
130 };
131 AsyncActivity& m_asyncActivity;
132 };
133
134 #include "tbb/enumerable_thread_specific.h"
135
136 class OutermostArenaBody {
137 public:
OutermostArenaBody(AsyncActivity & a_,tbb::task_arena & o_,tbb::task_arena & i_,tbb::task_arena & id_,tbb::enumerable_thread_specific<int> & ets)138 OutermostArenaBody(AsyncActivity& a_, tbb::task_arena& o_, tbb::task_arena& i_
139 , tbb::task_arena& id_, tbb::enumerable_thread_specific<int>& ets) :
140 m_asyncActivity(a_), m_outermostArena(o_), m_innermostArena(i_), m_innermostArenaDefault(id_), m_local(ets) {}
141
operator ()()142 void operator()() {
143 tbb::parallel_for(0, 32, *this);
144 }
145
operator ()(int i) const146 void operator()(int i) const {
147 tbb::task::suspend([&] (tbb::task::suspend_point sp) { m_asyncActivity.submit(sp); });
148
149 tbb::task_arena& nested_arena = (i % 3 == 0) ?
150 m_outermostArena : (i % 3 == 1 ? m_innermostArena : m_innermostArenaDefault);
151
152 if (i % 3 != 0) {
153 // We can only guarantee recall coorectness for "not-same" nested arenas entry
154 m_local.local() = i;
155 }
156 InnermostArenaBody innermost_arena_body(m_asyncActivity);
157 nested_arena.execute(innermost_arena_body);
158 if (i % 3 != 0) {
159 CHECK_MESSAGE(i == m_local.local(), "Original thread wasn't recalled for innermost nested arena.");
160 }
161 }
162
163 private:
164 AsyncActivity& m_asyncActivity;
165 tbb::task_arena& m_outermostArena;
166 tbb::task_arena& m_innermostArena;
167 tbb::task_arena& m_innermostArenaDefault;
168 tbb::enumerable_thread_specific<int>& m_local;
169 };
170
TestNestedArena()171 void TestNestedArena() {
172 AsyncActivity asyncActivity(4);
173
174 tbb::task_arena outermost_arena;
175 tbb::task_arena innermost_arena(2,2);
176 tbb::task_arena innermost_arena_default;
177
178 outermost_arena.initialize();
179 innermost_arena_default.initialize();
180 innermost_arena.initialize();
181
182 tbb::enumerable_thread_specific<int> ets;
183
184 OutermostArenaBody outer_arena_body(asyncActivity, outermost_arena, innermost_arena, innermost_arena_default, ets);
185 outermost_arena.execute(outer_arena_body);
186 }
187
188 // External activity used in all tests, which resumes suspended execution point
189 class EpochAsyncActivity {
190 public:
EpochAsyncActivity(int num_,std::atomic<int> & e_)191 EpochAsyncActivity(int num_, std::atomic<int>& e_) : m_numAsyncThreads(num_), m_globalEpoch(e_) {
192 for (int i = 0; i < m_numAsyncThreads ; ++i) {
193 m_asyncThreads.push_back( new std::thread(EpochAsyncActivity::asyncLoop, this) );
194 }
195 }
~EpochAsyncActivity()196 ~EpochAsyncActivity() {
197 {
198 std::lock_guard<std::mutex> lock(m_mutex);
199 for (int i = 0; i < m_numAsyncThreads; ++i) {
200 m_tagQueue.push(nullptr);
201 }
202 m_condvar.notify_all();
203 }
204 for (int i = 0; i < m_numAsyncThreads; ++i) {
205 m_asyncThreads[i]->join();
206 delete m_asyncThreads[i];
207 }
208 CHECK(m_tagQueue.empty());
209 }
submit(tbb::task::suspend_point ctx)210 void submit(tbb::task::suspend_point ctx) {
211 std::lock_guard<std::mutex> lock(m_mutex);
212 m_tagQueue.push(ctx);
213 m_condvar.notify_one();
214 }
215
216 private:
asyncLoop(EpochAsyncActivity * async)217 static void asyncLoop(EpochAsyncActivity* async) {
218 tbb::task::suspend_point tag;
219 for (;;) {
220 {
221 std::unique_lock<std::mutex> lock(async->m_mutex);
222 async->m_condvar.wait(lock, [async] {return !async->m_tagQueue.empty(); });
223 tag = async->m_tagQueue.front();
224 async->m_tagQueue.pop();
225 }
226 if (!tag) {
227 break;
228 }
229 // Track the global epoch
230 async->m_globalEpoch++;
231 tbb::task::resume(tag);
232 };
233 }
234
235 const int m_numAsyncThreads;
236 std::atomic<int>& m_globalEpoch;
237 std::mutex m_mutex;
238 std::condition_variable m_condvar;
239 std::queue<tbb::task::suspend_point> m_tagQueue;
240 std::vector<std::thread*> m_asyncThreads;
241 };
242
243 struct EpochSuspendBody {
EpochSuspendBodyEpochSuspendBody244 EpochSuspendBody(EpochAsyncActivity& a_, std::atomic<int>& e_, int& le_) :
245 m_asyncActivity(a_), m_globalEpoch(e_), m_localEpoch(le_) {}
246
operator ()EpochSuspendBody247 void operator()(tbb::task::suspend_point ctx) {
248 m_localEpoch = m_globalEpoch;
249 m_asyncActivity.submit(ctx);
250 }
251
252 private:
253 EpochAsyncActivity& m_asyncActivity;
254 std::atomic<int>& m_globalEpoch;
255 int& m_localEpoch;
256 };
257
258 // Simple test for basic resumable tasks functionality
TestSuspendResume()259 void TestSuspendResume() {
260 #if __TBB_USE_SANITIZERS
261 constexpr int iter_size = 100;
262 #else
263 constexpr int iter_size = 50000;
264 #endif
265
266 std::atomic<int> global_epoch; global_epoch = 0;
267 EpochAsyncActivity async(4, global_epoch);
268
269 tbb::enumerable_thread_specific<int, tbb::cache_aligned_allocator<int>, tbb::ets_suspend_aware> ets_fiber;
270 std::atomic<int> inner_par_iters, outer_par_iters;
271 inner_par_iters = outer_par_iters = 0;
272
273 tbb::parallel_for(0, N, [&](int) {
274 for (int i = 0; i < iter_size; ++i) {
275 ets_fiber.local() = i;
276
277 int local_epoch;
278 tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch));
279 CHECK(local_epoch < global_epoch);
280 CHECK(ets_fiber.local() == i);
281
282 tbb::parallel_for(0, N, [&](int) {
283 int local_epoch2;
284 tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch2));
285 CHECK(local_epoch2 < global_epoch);
286 ++inner_par_iters;
287 });
288
289 ets_fiber.local() = i;
290 tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch));
291 CHECK(local_epoch < global_epoch);
292 CHECK(ets_fiber.local() == i);
293 }
294 ++outer_par_iters;
295 });
296 CHECK(outer_par_iters == N);
297 CHECK(inner_par_iters == N*N*iter_size);
298 }
299
300 // During cleanup external thread's local task pool may
301 // e.g. contain proxies of affinitized tasks, but can be recalled
TestCleanupMaster()302 void TestCleanupMaster() {
303 if (tbb::this_task_arena::max_concurrency() == 1) {
304 // The test requires at least 2 threads
305 return;
306 }
307 AsyncActivity asyncActivity(4);
308 tbb::task_group tg;
309 std::atomic<int> iter_spawned;
310 std::atomic<int> iter_executed;
311
312 for (int i = 0; i < 100; i++) {
313 iter_spawned = 0;
314 iter_executed = 0;
315
316 utils::NativeParallelFor(N, [&asyncActivity, &tg, &iter_spawned, &iter_executed](int j) {
317 for (int k = 0; k < j*10 + 1; ++k) {
318 tg.run([&asyncActivity, j, &iter_executed] {
319 utils::doDummyWork(j * 10);
320 tbb::task::suspend(SuspendBody(asyncActivity, std::this_thread::get_id()));
321 iter_executed++;
322 });
323 iter_spawned++;
324 }
325 });
326 CHECK(iter_spawned == 460);
327 tg.wait();
328 CHECK(iter_executed == 460);
329 }
330 }
331 class ParForSuspendBody {
332 AsyncActivity& asyncActivity;
333 int m_numIters;
334 public:
ParForSuspendBody(AsyncActivity & a_,int iters)335 ParForSuspendBody(AsyncActivity& a_, int iters) : asyncActivity(a_), m_numIters(iters) {}
operator ()(int) const336 void operator()(int) const {
337 utils::doDummyWork(m_numIters);
338 tbb::task::suspend(SuspendBody(asyncActivity, std::this_thread::get_id()));
339 }
340 };
341
TestNativeThread()342 void TestNativeThread() {
343 AsyncActivity asyncActivity(4);
344
345 tbb::task_arena arena;
346 tbb::task_group tg;
347 std::atomic<int> iter{};
348 utils::NativeParallelFor(arena.max_concurrency() / 2, [&arena, &tg, &asyncActivity, &iter](int) {
349 for (int i = 0; i < 10; i++) {
350 arena.execute([&tg, &asyncActivity, &iter]() {
351 tg.run([&asyncActivity]() {
352 tbb::task::suspend(SuspendBody(asyncActivity, std::this_thread::get_id()));
353 });
354 iter++;
355 });
356 }
357 });
358
359 CHECK(iter == (arena.max_concurrency() / 2 * 10));
360 arena.execute([&tg](){
361 tg.wait();
362 });
363 }
364
365 class ObserverTracker : public tbb::task_scheduler_observer {
366 static thread_local bool is_in_arena;
367 public:
368 std::atomic<int> counter;
369
ObserverTracker(tbb::task_arena & a)370 ObserverTracker(tbb::task_arena& a) : tbb::task_scheduler_observer(a) {
371 counter = 0;
372 observe(true);
373 }
on_scheduler_entry(bool)374 void on_scheduler_entry(bool) override {
375 bool& l = is_in_arena;
376 CHECK_MESSAGE(l == false, "The thread must call on_scheduler_entry only one time.");
377 l = true;
378 ++counter;
379 }
on_scheduler_exit(bool)380 void on_scheduler_exit(bool) override {
381 bool& l = is_in_arena;
382 CHECK_MESSAGE(l == true, "The thread must call on_scheduler_entry before calling on_scheduler_exit.");
383 l = false;
384 }
385 };
386
387 thread_local bool ObserverTracker::is_in_arena;
388
TestObservers()389 void TestObservers() {
390 tbb::task_arena arena;
391 ObserverTracker tracker(arena);
392 do {
393 arena.execute([] {
394 tbb::parallel_for(0, 10, [](int) {
395 auto thread_id = std::this_thread::get_id();
396 tbb::task::suspend([thread_id](tbb::task::suspend_point tag) {
397 CHECK(thread_id == std::this_thread::get_id());
398 tbb::task::resume(tag);
399 });
400 }, tbb::simple_partitioner());
401 });
402 } while (tracker.counter < 100);
403 tracker.observe(false);
404 }
405
406 class TestCaseGuard {
407 static thread_local bool m_local;
408 tbb::global_control m_threadLimit;
409 tbb::global_control m_stackLimit;
410 public:
TestCaseGuard()411 TestCaseGuard()
412 : m_threadLimit(tbb::global_control::max_allowed_parallelism, std::max(tbb::this_task_arena::max_concurrency(), 16))
413 , m_stackLimit(tbb::global_control::thread_stack_size, 128*1024)
414 {
415 CHECK(m_local == false);
416 m_local = true;
417 }
~TestCaseGuard()418 ~TestCaseGuard() {
419 CHECK(m_local == true);
420 m_local = false;
421 }
422 };
423
424 thread_local bool TestCaseGuard::m_local = false;
425
426 //! Nested test for suspend and resume
427 //! \brief \ref error_guessing
428 TEST_CASE("Nested test for suspend and resume") {
429 TestCaseGuard guard;
430 TestSuspendResume();
431 }
432
433 //! Nested arena complex test
434 //! \brief \ref error_guessing
435 TEST_CASE("Nested arena") {
436 TestCaseGuard guard;
437 TestNestedArena();
438 }
439
440 //! Test with external threads
441 //! \brief \ref error_guessing
442 TEST_CASE("External threads") {
443 TestNativeThread();
444 }
445
446 //! Stress test with external threads
447 //! \brief \ref stress
448 TEST_CASE("Stress test with external threads") {
449 TestCleanupMaster();
450 }
451
452 //! Test with an arena observer
453 //! \brief \ref error_guessing
454 TEST_CASE("Arena observer") {
455 TestObservers();
456 }
457 #endif /* __TBB_RESUMABLE_TASKS */
458