xref: /oneTBB/test/tbb/test_resumable_tasks.cpp (revision e77098d6)
/*
    Copyright (c) 2005-2022 Intel Corporation

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/
16 
//! \file test_resumable_tasks.cpp
//! \brief Test for [scheduler.resumable_tasks] specification
19 
#include "common/test.h"
#include "common/utils.h"

#include "tbb/task.h"

#if __TBB_RESUMABLE_TASKS

#include "tbb/global_control.h"
#include "tbb/task_arena.h"
#include "tbb/parallel_for.h"
#include "tbb/task_scheduler_observer.h"
#include "tbb/task_group.h"

#include <algorithm>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
37 
// Problem size shared by the nested parallel loops in the tests below.
const int N = 10;
39 
40 // External activity used in all tests, which resumes suspended execution point
41 class AsyncActivity {
42 public:
AsyncActivity(int num_)43     AsyncActivity(int num_) : m_numAsyncThreads(num_) {
44         for (int i = 0; i < m_numAsyncThreads ; ++i) {
45             m_asyncThreads.push_back( new std::thread(AsyncActivity::asyncLoop, this) );
46         }
47     }
~AsyncActivity()48     ~AsyncActivity() {
49         {
50             std::lock_guard<std::mutex> lock(m_mutex);
51             for (int i = 0; i < m_numAsyncThreads; ++i) {
52                 m_tagQueue.push(nullptr);
53             }
54             m_condvar.notify_all();
55         }
56         for (int i = 0; i < m_numAsyncThreads; ++i) {
57             m_asyncThreads[i]->join();
58             delete m_asyncThreads[i];
59         }
60         CHECK(m_tagQueue.empty());
61     }
submit(tbb::task::suspend_point ctx)62     void submit(tbb::task::suspend_point ctx) {
63         std::lock_guard<std::mutex> lock(m_mutex);
64         m_tagQueue.push(ctx);
65         m_condvar.notify_one();
66     }
67 
68 private:
asyncLoop(AsyncActivity * async)69     static void asyncLoop(AsyncActivity* async) {
70         tbb::task::suspend_point tag;
71         for (;;) {
72             {
73                 std::unique_lock<std::mutex> lock(async->m_mutex);
74                 async->m_condvar.wait(lock, [async] {return !async->m_tagQueue.empty(); });
75                 tag = async->m_tagQueue.front();
76                 async->m_tagQueue.pop();
77             }
78             if (!tag) {
79                 break;
80             }
81             tbb::task::resume(tag);
82         };
83     }
84 
85     const int m_numAsyncThreads;
86     std::mutex m_mutex;
87     std::condition_variable m_condvar;
88     std::queue<tbb::task::suspend_point> m_tagQueue;
89     std::vector<std::thread*> m_asyncThreads;
90 };
91 
92 struct SuspendBody {
SuspendBodySuspendBody93     SuspendBody(AsyncActivity& a_, std::thread::id id) :
94         m_asyncActivity(a_), thread_id(id) {}
operator ()SuspendBody95     void operator()(tbb::task::suspend_point tag) {
96         CHECK(thread_id == std::this_thread::get_id());
97         m_asyncActivity.submit(tag);
98     }
99 
100 private:
101     AsyncActivity& m_asyncActivity;
102     std::thread::id thread_id;
103 };
104 
105 class InnermostArenaBody {
106 public:
InnermostArenaBody(AsyncActivity & a_)107     InnermostArenaBody(AsyncActivity& a_) : m_asyncActivity(a_) {}
108 
operator ()()109     void operator()() {
110         InnermostOuterParFor inner_outer_body(m_asyncActivity);
111         tbb::parallel_for(0, N, inner_outer_body );
112     }
113 
114 private:
115     struct InnermostInnerParFor {
InnermostInnerParForInnermostArenaBody::InnermostInnerParFor116         InnermostInnerParFor(AsyncActivity& a_) : m_asyncActivity(a_) {}
operator ()InnermostArenaBody::InnermostInnerParFor117         void operator()(int) const {
118             tbb::task::suspend(SuspendBody(m_asyncActivity, std::this_thread::get_id()));
119         }
120         AsyncActivity& m_asyncActivity;
121     };
122     struct InnermostOuterParFor {
InnermostOuterParForInnermostArenaBody::InnermostOuterParFor123         InnermostOuterParFor(AsyncActivity& a_) : m_asyncActivity(a_) {}
operator ()InnermostArenaBody::InnermostOuterParFor124         void operator()(int) const {
125             tbb::task::suspend(SuspendBody(m_asyncActivity, std::this_thread::get_id()));
126             InnermostInnerParFor inner_inner_body(m_asyncActivity);
127             tbb::parallel_for(0, N, inner_inner_body);
128         }
129         AsyncActivity& m_asyncActivity;
130     };
131     AsyncActivity& m_asyncActivity;
132 };
133 
134 #include "tbb/enumerable_thread_specific.h"
135 
136 class OutermostArenaBody {
137 public:
OutermostArenaBody(AsyncActivity & a_,tbb::task_arena & o_,tbb::task_arena & i_,tbb::task_arena & id_,tbb::enumerable_thread_specific<int> & ets)138     OutermostArenaBody(AsyncActivity& a_, tbb::task_arena& o_, tbb::task_arena& i_
139             , tbb::task_arena& id_, tbb::enumerable_thread_specific<int>& ets) :
140         m_asyncActivity(a_), m_outermostArena(o_), m_innermostArena(i_), m_innermostArenaDefault(id_), m_local(ets) {}
141 
operator ()()142     void operator()() {
143         tbb::parallel_for(0, 32, *this);
144     }
145 
operator ()(int i) const146     void operator()(int i) const {
147         tbb::task::suspend([&] (tbb::task::suspend_point sp) { m_asyncActivity.submit(sp); });
148 
149         tbb::task_arena& nested_arena = (i % 3 == 0) ?
150             m_outermostArena : (i % 3 == 1 ? m_innermostArena : m_innermostArenaDefault);
151 
152         if (i % 3 != 0) {
153             // We can only guarantee recall coorectness for "not-same" nested arenas entry
154             m_local.local() = i;
155         }
156         InnermostArenaBody innermost_arena_body(m_asyncActivity);
157         nested_arena.execute(innermost_arena_body);
158         if (i % 3 != 0) {
159             CHECK_MESSAGE(i == m_local.local(), "Original thread wasn't recalled for innermost nested arena.");
160         }
161     }
162 
163 private:
164     AsyncActivity& m_asyncActivity;
165     tbb::task_arena& m_outermostArena;
166     tbb::task_arena& m_innermostArena;
167     tbb::task_arena& m_innermostArenaDefault;
168     tbb::enumerable_thread_specific<int>& m_local;
169 };
170 
TestNestedArena()171 void TestNestedArena() {
172     AsyncActivity asyncActivity(4);
173 
174     tbb::task_arena outermost_arena;
175     tbb::task_arena innermost_arena(2,2);
176     tbb::task_arena innermost_arena_default;
177 
178     outermost_arena.initialize();
179     innermost_arena_default.initialize();
180     innermost_arena.initialize();
181 
182     tbb::enumerable_thread_specific<int> ets;
183 
184     OutermostArenaBody outer_arena_body(asyncActivity, outermost_arena, innermost_arena, innermost_arena_default, ets);
185     outermost_arena.execute(outer_arena_body);
186 }
187 
188 // External activity used in all tests, which resumes suspended execution point
189 class EpochAsyncActivity {
190 public:
EpochAsyncActivity(int num_,std::atomic<int> & e_)191     EpochAsyncActivity(int num_, std::atomic<int>& e_) : m_numAsyncThreads(num_), m_globalEpoch(e_) {
192         for (int i = 0; i < m_numAsyncThreads ; ++i) {
193             m_asyncThreads.push_back( new std::thread(EpochAsyncActivity::asyncLoop, this) );
194         }
195     }
~EpochAsyncActivity()196     ~EpochAsyncActivity() {
197         {
198             std::lock_guard<std::mutex> lock(m_mutex);
199             for (int i = 0; i < m_numAsyncThreads; ++i) {
200                 m_tagQueue.push(nullptr);
201             }
202             m_condvar.notify_all();
203         }
204         for (int i = 0; i < m_numAsyncThreads; ++i) {
205             m_asyncThreads[i]->join();
206             delete m_asyncThreads[i];
207         }
208         CHECK(m_tagQueue.empty());
209     }
submit(tbb::task::suspend_point ctx)210     void submit(tbb::task::suspend_point ctx) {
211         std::lock_guard<std::mutex> lock(m_mutex);
212         m_tagQueue.push(ctx);
213         m_condvar.notify_one();
214     }
215 
216 private:
asyncLoop(EpochAsyncActivity * async)217     static void asyncLoop(EpochAsyncActivity* async) {
218         tbb::task::suspend_point tag;
219         for (;;) {
220             {
221                 std::unique_lock<std::mutex> lock(async->m_mutex);
222                 async->m_condvar.wait(lock, [async] {return !async->m_tagQueue.empty(); });
223                 tag = async->m_tagQueue.front();
224                 async->m_tagQueue.pop();
225             }
226             if (!tag) {
227                 break;
228             }
229             // Track the global epoch
230             async->m_globalEpoch++;
231             tbb::task::resume(tag);
232         };
233     }
234 
235     const int m_numAsyncThreads;
236     std::atomic<int>& m_globalEpoch;
237     std::mutex m_mutex;
238     std::condition_variable m_condvar;
239     std::queue<tbb::task::suspend_point> m_tagQueue;
240     std::vector<std::thread*> m_asyncThreads;
241 };
242 
243 struct EpochSuspendBody {
EpochSuspendBodyEpochSuspendBody244     EpochSuspendBody(EpochAsyncActivity& a_, std::atomic<int>& e_, int& le_) :
245         m_asyncActivity(a_), m_globalEpoch(e_), m_localEpoch(le_) {}
246 
operator ()EpochSuspendBody247     void operator()(tbb::task::suspend_point ctx) {
248         m_localEpoch = m_globalEpoch;
249         m_asyncActivity.submit(ctx);
250     }
251 
252 private:
253     EpochAsyncActivity& m_asyncActivity;
254     std::atomic<int>& m_globalEpoch;
255     int& m_localEpoch;
256 };
257 
258 // Simple test for basic resumable tasks functionality
TestSuspendResume()259 void TestSuspendResume() {
260 #if __TBB_USE_SANITIZERS
261     constexpr int iter_size = 100;
262 #else
263     constexpr int iter_size = 50000;
264 #endif
265 
266     std::atomic<int> global_epoch; global_epoch = 0;
267     EpochAsyncActivity async(4, global_epoch);
268 
269     tbb::enumerable_thread_specific<int, tbb::cache_aligned_allocator<int>, tbb::ets_suspend_aware> ets_fiber;
270     std::atomic<int> inner_par_iters, outer_par_iters;
271     inner_par_iters = outer_par_iters = 0;
272 
273     tbb::parallel_for(0, N, [&](int) {
274         for (int i = 0; i < iter_size; ++i) {
275             ets_fiber.local() = i;
276 
277             int local_epoch;
278             tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch));
279             CHECK(local_epoch < global_epoch);
280             CHECK(ets_fiber.local() == i);
281 
282             tbb::parallel_for(0, N, [&](int) {
283                 int local_epoch2;
284                 tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch2));
285                 CHECK(local_epoch2 < global_epoch);
286                 ++inner_par_iters;
287             });
288 
289             ets_fiber.local() = i;
290             tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch));
291             CHECK(local_epoch < global_epoch);
292             CHECK(ets_fiber.local() == i);
293         }
294         ++outer_par_iters;
295     });
296     CHECK(outer_par_iters == N);
297     CHECK(inner_par_iters == N*N*iter_size);
298 }
299 
300 // During cleanup external thread's local task pool may
301 // e.g. contain proxies of affinitized tasks, but can be recalled
TestCleanupMaster()302 void TestCleanupMaster() {
303     if (tbb::this_task_arena::max_concurrency() == 1) {
304         // The test requires at least 2 threads
305         return;
306     }
307     AsyncActivity asyncActivity(4);
308     tbb::task_group tg;
309     std::atomic<int> iter_spawned;
310     std::atomic<int> iter_executed;
311 
312     for (int i = 0; i < 100; i++) {
313         iter_spawned = 0;
314         iter_executed = 0;
315 
316         utils::NativeParallelFor(N, [&asyncActivity, &tg, &iter_spawned, &iter_executed](int j) {
317             for (int k = 0; k < j*10 + 1; ++k) {
318                 tg.run([&asyncActivity, j, &iter_executed] {
319                     utils::doDummyWork(j * 10);
320                     tbb::task::suspend(SuspendBody(asyncActivity, std::this_thread::get_id()));
321                     iter_executed++;
322                 });
323                 iter_spawned++;
324             }
325         });
326         CHECK(iter_spawned == 460);
327         tg.wait();
328         CHECK(iter_executed == 460);
329     }
330 }
331 class ParForSuspendBody {
332     AsyncActivity& asyncActivity;
333     int m_numIters;
334 public:
ParForSuspendBody(AsyncActivity & a_,int iters)335     ParForSuspendBody(AsyncActivity& a_, int iters) : asyncActivity(a_), m_numIters(iters) {}
operator ()(int) const336     void operator()(int) const {
337         utils::doDummyWork(m_numIters);
338         tbb::task::suspend(SuspendBody(asyncActivity, std::this_thread::get_id()));
339     }
340 };
341 
TestNativeThread()342 void TestNativeThread() {
343     AsyncActivity asyncActivity(4);
344 
345     tbb::task_arena arena;
346     tbb::task_group tg;
347     std::atomic<int> iter{};
348     utils::NativeParallelFor(arena.max_concurrency() / 2, [&arena, &tg, &asyncActivity, &iter](int) {
349         for (int i = 0; i < 10; i++) {
350             arena.execute([&tg, &asyncActivity, &iter]() {
351                 tg.run([&asyncActivity]() {
352                     tbb::task::suspend(SuspendBody(asyncActivity, std::this_thread::get_id()));
353                 });
354                 iter++;
355             });
356         }
357     });
358 
359     CHECK(iter == (arena.max_concurrency() / 2 * 10));
360     arena.execute([&tg](){
361         tg.wait();
362     });
363 }
364 
365 class ObserverTracker : public tbb::task_scheduler_observer {
366     static thread_local bool is_in_arena;
367 public:
368     std::atomic<int> counter;
369 
ObserverTracker(tbb::task_arena & a)370     ObserverTracker(tbb::task_arena& a) : tbb::task_scheduler_observer(a) {
371         counter = 0;
372         observe(true);
373     }
on_scheduler_entry(bool)374     void on_scheduler_entry(bool) override {
375         bool& l = is_in_arena;
376         CHECK_MESSAGE(l == false, "The thread must call on_scheduler_entry only one time.");
377         l = true;
378         ++counter;
379     }
on_scheduler_exit(bool)380     void on_scheduler_exit(bool) override {
381         bool& l = is_in_arena;
382         CHECK_MESSAGE(l == true, "The thread must call on_scheduler_entry before calling on_scheduler_exit.");
383         l = false;
384     }
385 };
386 
387 thread_local bool ObserverTracker::is_in_arena;
388 
TestObservers()389 void TestObservers() {
390     tbb::task_arena arena;
391     ObserverTracker tracker(arena);
392     do {
393         arena.execute([] {
394             tbb::parallel_for(0, 10, [](int) {
395                 auto thread_id = std::this_thread::get_id();
396                 tbb::task::suspend([thread_id](tbb::task::suspend_point tag) {
397                     CHECK(thread_id == std::this_thread::get_id());
398                     tbb::task::resume(tag);
399                 });
400             }, tbb::simple_partitioner());
401         });
402     } while (tracker.counter < 100);
403     tracker.observe(false);
404 }
405 
406 class TestCaseGuard {
407     static thread_local bool m_local;
408     tbb::global_control m_threadLimit;
409     tbb::global_control m_stackLimit;
410 public:
TestCaseGuard()411     TestCaseGuard()
412         : m_threadLimit(tbb::global_control::max_allowed_parallelism, std::max(tbb::this_task_arena::max_concurrency(), 16))
413         , m_stackLimit(tbb::global_control::thread_stack_size, 128*1024)
414     {
415         CHECK(m_local == false);
416         m_local = true;
417     }
~TestCaseGuard()418     ~TestCaseGuard() {
419         CHECK(m_local == true);
420         m_local = false;
421     }
422 };
423 
424 thread_local bool TestCaseGuard::m_local = false;
425 
426 //! Nested test for suspend and resume
427 //! \brief \ref error_guessing
428 TEST_CASE("Nested test for suspend and resume") {
429     TestCaseGuard guard;
430     TestSuspendResume();
431 }
432 
433 //! Nested arena complex test
434 //! \brief \ref error_guessing
435 TEST_CASE("Nested arena") {
436     TestCaseGuard guard;
437     TestNestedArena();
438 }
439 
440 //! Test with external threads
441 //! \brief \ref error_guessing
442 TEST_CASE("External threads") {
443     TestNativeThread();
444 }
445 
446 //! Stress test with external threads
447 //! \brief \ref stress
448 TEST_CASE("Stress test with external threads") {
449     TestCleanupMaster();
450 }
451 
452 //! Test with an arena observer
453 //! \brief \ref error_guessing
454 TEST_CASE("Arena observer") {
455     TestObservers();
456 }
457 #endif /* __TBB_RESUMABLE_TASKS */
458