xref: /oneTBB/test/tbb/test_resumable_tasks.cpp (revision adf44fbf)
1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 //! \file test_resumable_tasks.cpp
18 //! \brief Test for [scheduler.resumable_tasks] specification
19 
20 #include "common/test.h"
21 #include "common/utils.h"
22 
23 #include "tbb/task.h"
24 
25 #if __TBB_RESUMABLE_TASKS
26 
27 #include "tbb/global_control.h"
28 #include "tbb/task_arena.h"
29 #include "tbb/parallel_for.h"
30 #include "tbb/task_scheduler_observer.h"
31 #include "tbb/task_group.h"
32 
33 #include <algorithm>
34 #include <thread>
35 #include <queue>
36 #include <condition_variable>
37 
38 const int N = 10;
39 
40 // External activity used in all tests, which resumes suspended execution point
41 class AsyncActivity {
42 public:
43     AsyncActivity(int num_) : m_numAsyncThreads(num_) {
44         for (int i = 0; i < m_numAsyncThreads ; ++i) {
45             m_asyncThreads.push_back( new std::thread(AsyncActivity::asyncLoop, this) );
46         }
47     }
48     ~AsyncActivity() {
49         {
50             std::lock_guard<std::mutex> lock(m_mutex);
51             for (int i = 0; i < m_numAsyncThreads; ++i) {
52                 m_tagQueue.push(nullptr);
53             }
54             m_condvar.notify_all();
55         }
56         for (int i = 0; i < m_numAsyncThreads; ++i) {
57             m_asyncThreads[i]->join();
58             delete m_asyncThreads[i];
59         }
60         CHECK(m_tagQueue.empty());
61     }
62     void submit(tbb::task::suspend_point ctx) {
63         std::lock_guard<std::mutex> lock(m_mutex);
64         m_tagQueue.push(ctx);
65         m_condvar.notify_one();
66     }
67 
68 private:
69     static void asyncLoop(AsyncActivity* async) {
70         tbb::task::suspend_point tag;
71         for (;;) {
72             {
73                 std::unique_lock<std::mutex> lock(async->m_mutex);
74                 async->m_condvar.wait(lock, [async] {return !async->m_tagQueue.empty(); });
75                 tag = async->m_tagQueue.front();
76                 async->m_tagQueue.pop();
77             }
78             if (!tag) {
79                 break;
80             }
81             tbb::task::resume(tag);
82         };
83     }
84 
85     const int m_numAsyncThreads;
86     std::mutex m_mutex;
87     std::condition_variable m_condvar;
88     std::queue<tbb::task::suspend_point> m_tagQueue;
89     std::vector<std::thread*> m_asyncThreads;
90 };
91 
92 struct SuspendBody {
93     SuspendBody(AsyncActivity& a_) :
94         m_asyncActivity(a_) {}
95     void operator()(tbb::task::suspend_point tag) {
96         m_asyncActivity.submit(tag);
97     }
98 
99 private:
100     AsyncActivity& m_asyncActivity;
101 };
102 
103 class InnermostArenaBody {
104 public:
105     InnermostArenaBody(AsyncActivity& a_) : m_asyncActivity(a_) {}
106 
107     void operator()() {
108         InnermostOuterParFor inner_outer_body(m_asyncActivity);
109         tbb::parallel_for(0, N, inner_outer_body );
110     }
111 
112 private:
113     struct InnermostInnerParFor {
114         InnermostInnerParFor(AsyncActivity& a_) : m_asyncActivity(a_) {}
115         void operator()(int) const {
116             tbb::task::suspend(SuspendBody(m_asyncActivity));
117         }
118         AsyncActivity& m_asyncActivity;
119     };
120     struct InnermostOuterParFor {
121         InnermostOuterParFor(AsyncActivity& a_) : m_asyncActivity(a_) {}
122         void operator()(int) const {
123             tbb::task::suspend(SuspendBody(m_asyncActivity));
124             InnermostInnerParFor inner_inner_body(m_asyncActivity);
125             tbb::parallel_for(0, N, inner_inner_body);
126         }
127         AsyncActivity& m_asyncActivity;
128     };
129     AsyncActivity& m_asyncActivity;
130 };
131 
132 #include "tbb/enumerable_thread_specific.h"
133 
// Body executed in the outermost arena: each iteration suspends, enters one
// of the nested arenas, and checks the original thread is recalled.
class OutermostArenaBody {
public:
    OutermostArenaBody(AsyncActivity& a_, tbb::task_arena& o_, tbb::task_arena& i_
            , tbb::task_arena& id_, tbb::enumerable_thread_specific<int>& ets) :
        m_asyncActivity(a_), m_outermostArena(o_), m_innermostArena(i_), m_innermostArenaDefault(id_), m_local(ets) {}

    // Entry point: spawns 32 iterations of operator()(int) below.
    void operator()() {
        tbb::parallel_for(0, 32, *this);
    }

    void operator()(int i) const {
        tbb::task::suspend(SuspendBody(m_asyncActivity));

        // Cycle through: same (outermost) arena, explicit innermost, default innermost.
        tbb::task_arena& nested_arena = (i % 3 == 0) ?
            m_outermostArena : (i % 3 == 1 ? m_innermostArena : m_innermostArenaDefault);

        if (i % 3 != 0) {
            // We can only guarantee recall correctness for "not-same" nested arenas entry
            m_local.local() = i;
        }
        InnermostArenaBody innermost_arena_body(m_asyncActivity);
        nested_arena.execute(innermost_arena_body);
        if (i % 3 != 0) {
            // The thread-local value must still belong to this iteration,
            // i.e. the same thread returned from the nested arena.
            CHECK_MESSAGE(i == m_local.local(), "Original thread wasn't recalled for innermost nested arena.");
        }
    }

private:
    AsyncActivity& m_asyncActivity;
    tbb::task_arena& m_outermostArena;
    tbb::task_arena& m_innermostArena;
    tbb::task_arena& m_innermostArenaDefault;
    tbb::enumerable_thread_specific<int>& m_local;
};
168 
169 void TestNestedArena() {
170     AsyncActivity asyncActivity(4);
171 
172     tbb::task_arena outermost_arena;
173     tbb::task_arena innermost_arena(2,2);
174     tbb::task_arena innermost_arena_default;
175 
176     outermost_arena.initialize();
177     innermost_arena_default.initialize();
178     innermost_arena.initialize();
179 
180     tbb::enumerable_thread_specific<int> ets;
181 
182     OutermostArenaBody outer_arena_body(asyncActivity, outermost_arena, innermost_arena, innermost_arena_default, ets);
183     outermost_arena.execute(outer_arena_body);
184 }
185 
186 // External activity used in all tests, which resumes suspended execution point
187 class EpochAsyncActivity {
188 public:
189     EpochAsyncActivity(int num_, std::atomic<int>& e_) : m_numAsyncThreads(num_), m_globalEpoch(e_) {
190         for (int i = 0; i < m_numAsyncThreads ; ++i) {
191             m_asyncThreads.push_back( new std::thread(EpochAsyncActivity::asyncLoop, this) );
192         }
193     }
194     ~EpochAsyncActivity() {
195         {
196             std::lock_guard<std::mutex> lock(m_mutex);
197             for (int i = 0; i < m_numAsyncThreads; ++i) {
198                 m_tagQueue.push(nullptr);
199             }
200             m_condvar.notify_all();
201         }
202         for (int i = 0; i < m_numAsyncThreads; ++i) {
203             m_asyncThreads[i]->join();
204             delete m_asyncThreads[i];
205         }
206         CHECK(m_tagQueue.empty());
207     }
208     void submit(tbb::task::suspend_point ctx) {
209         std::lock_guard<std::mutex> lock(m_mutex);
210         m_tagQueue.push(ctx);
211         m_condvar.notify_one();
212     }
213 
214 private:
215     static void asyncLoop(EpochAsyncActivity* async) {
216         tbb::task::suspend_point tag;
217         for (;;) {
218             {
219                 std::unique_lock<std::mutex> lock(async->m_mutex);
220                 async->m_condvar.wait(lock, [async] {return !async->m_tagQueue.empty(); });
221                 tag = async->m_tagQueue.front();
222                 async->m_tagQueue.pop();
223             }
224             if (!tag) {
225                 break;
226             }
227             // Track the global epoch
228             async->m_globalEpoch++;
229             tbb::task::resume(tag);
230         };
231     }
232 
233     const int m_numAsyncThreads;
234     std::atomic<int>& m_globalEpoch;
235     std::mutex m_mutex;
236     std::condition_variable m_condvar;
237     std::queue<tbb::task::suspend_point> m_tagQueue;
238     std::vector<std::thread*> m_asyncThreads;
239 };
240 
241 struct EpochSuspendBody {
242     EpochSuspendBody(EpochAsyncActivity& a_, std::atomic<int>& e_, int& le_) :
243         m_asyncActivity(a_), m_globalEpoch(e_), m_localEpoch(le_) {}
244 
245     void operator()(tbb::task::suspend_point ctx) {
246         m_localEpoch = m_globalEpoch;
247         m_asyncActivity.submit(ctx);
248     }
249 
250 private:
251     EpochAsyncActivity& m_asyncActivity;
252     std::atomic<int>& m_globalEpoch;
253     int& m_localEpoch;
254 };
255 
// Simple test for basic resumable tasks functionality
void TestSuspendResume() {
#if __TBB_USE_SANITIZERS
    // Sanitizer builds are much slower; keep the iteration count small.
    constexpr int iter_size = 100;
#else
    constexpr int iter_size = 50000;
#endif

    std::atomic<int> global_epoch; global_epoch = 0;
    EpochAsyncActivity async(4, global_epoch);

    // Suspend-aware TLS: its value must be preserved across suspend/resume.
    tbb::enumerable_thread_specific<int, tbb::cache_aligned_allocator<int>, tbb::ets_suspend_aware> ets_fiber;
    std::atomic<int> inner_par_iters, outer_par_iters;
    inner_par_iters = outer_par_iters = 0;

    tbb::parallel_for(0, N, [&](int) {
        for (int i = 0; i < iter_size; ++i) {
            ets_fiber.local() = i;

            int local_epoch;
            tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch));
            // The epoch captured before suspension must be behind the global
            // epoch, which the resumer incremented before resuming us ...
            CHECK(local_epoch < global_epoch);
            // ... while the suspend-aware TLS value is unchanged.
            CHECK(ets_fiber.local() == i);

            // Nested parallelism with suspensions inside each inner iteration.
            tbb::parallel_for(0, N, [&](int) {
                int local_epoch2;
                tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch2));
                CHECK(local_epoch2 < global_epoch);
                ++inner_par_iters;
            });

            // Repeat the epoch/TLS checks after the nested parallel_for.
            ets_fiber.local() = i;
            tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch));
            CHECK(local_epoch < global_epoch);
            CHECK(ets_fiber.local() == i);
        }
        ++outer_par_iters;
    });
    // Every outer and inner iteration must have completed exactly once.
    CHECK(outer_par_iters == N);
    CHECK(inner_par_iters == N*N*iter_size);
}
297 
298 // During cleanup external thread's local task pool may
299 // e.g. contain proxies of affinitized tasks, but can be recalled
300 void TestCleanupMaster() {
301     if (tbb::this_task_arena::max_concurrency() == 1) {
302         // The test requires at least 2 threads
303         return;
304     }
305     AsyncActivity asyncActivity(4);
306     tbb::task_group tg;
307     std::atomic<int> iter_spawned;
308     std::atomic<int> iter_executed;
309 
310     for (int i = 0; i < 100; i++) {
311         iter_spawned = 0;
312         iter_executed = 0;
313 
314         utils::NativeParallelFor(N, [&asyncActivity, &tg, &iter_spawned, &iter_executed](int j) {
315             for (int k = 0; k < j*10 + 1; ++k) {
316                 tg.run([&asyncActivity, j, &iter_executed] {
317                     utils::doDummyWork(j * 10);
318                     tbb::task::suspend(SuspendBody(asyncActivity));
319                     iter_executed++;
320                 });
321                 iter_spawned++;
322             }
323         });
324         CHECK(iter_spawned == 460);
325         tg.wait();
326         CHECK(iter_executed == 460);
327     }
328 }
329 class ParForSuspendBody {
330     AsyncActivity& asyncActivity;
331     int m_numIters;
332 public:
333     ParForSuspendBody(AsyncActivity& a_, int iters) : asyncActivity(a_), m_numIters(iters) {}
334     void operator()(int) const {
335         utils::doDummyWork(m_numIters);
336         tbb::task::suspend(SuspendBody(asyncActivity));
337     }
338 };
339 
340 void TestNativeThread() {
341     AsyncActivity asyncActivity(4);
342 
343     tbb::task_arena arena;
344     tbb::task_group tg;
345     std::atomic<int> iter{};
346     utils::NativeParallelFor(arena.max_concurrency() / 2, [&arena, &tg, &asyncActivity, &iter](int) {
347         for (int i = 0; i < 10; i++) {
348             arena.execute([&tg, &asyncActivity, &iter]() {
349                 tg.run([&asyncActivity]() {
350                     tbb::task::suspend(SuspendBody(asyncActivity));
351                 });
352                 iter++;
353             });
354         }
355     });
356 
357     CHECK(iter == (arena.max_concurrency() / 2 * 10));
358     arena.execute([&tg](){
359         tg.wait();
360     });
361 }
362 
363 class ObserverTracker : public tbb::task_scheduler_observer {
364     static thread_local bool is_in_arena;
365 public:
366     std::atomic<int> counter;
367 
368     ObserverTracker(tbb::task_arena& a) : tbb::task_scheduler_observer(a) {
369         counter = 0;
370         observe(true);
371     }
372     void on_scheduler_entry(bool) override {
373         bool& l = is_in_arena;
374         CHECK_MESSAGE(l == false, "The thread must call on_scheduler_entry only one time.");
375         l = true;
376         ++counter;
377     }
378     void on_scheduler_exit(bool) override {
379         bool& l = is_in_arena;
380         CHECK_MESSAGE(l == true, "The thread must call on_scheduler_entry before calling on_scheduler_exit.");
381         l = false;
382     }
383 };
384 
385 thread_local bool ObserverTracker::is_in_arena;
386 
// Runs suspend/resume inside an observed arena until enough entry events
// accumulate, checking (via ObserverTracker) that callbacks stay balanced.
void TestObservers() {
    tbb::task_arena arena;
    ObserverTracker tracker(arena);
    do {
        arena.execute([] {
            tbb::parallel_for(0, 10, [](int) {
                // Trivial suspension: resume immediately from the callback.
                tbb::task::suspend([](tbb::task::suspend_point tag) {
                    tbb::task::resume(tag);
                });
            }, tbb::simple_partitioner());
        });
    } while (tracker.counter < 100); // repeat until enough entries are observed
    tracker.observe(false);
}
401 
402 class TestCaseGuard {
403     static thread_local bool m_local;
404     tbb::global_control m_threadLimit;
405     tbb::global_control m_stackLimit;
406 public:
407     TestCaseGuard()
408         : m_threadLimit(tbb::global_control::max_allowed_parallelism, std::max(tbb::this_task_arena::max_concurrency(), 16))
409         , m_stackLimit(tbb::global_control::thread_stack_size, 128*1024)
410     {
411         CHECK(m_local == false);
412         m_local = true;
413     }
414     ~TestCaseGuard() {
415         CHECK(m_local == true);
416         m_local = false;
417     }
418 };
419 
420 thread_local bool TestCaseGuard::m_local = false;
421 
//! Nested test for suspend and resume
//! \brief \ref error_guessing
TEST_CASE("Nested test for suspend and resume") {
    // Guard limits parallelism/stack size for the duration of the test.
    TestCaseGuard guard;
    TestSuspendResume();
}

//! Nested arena complex test
//! \brief \ref error_guessing
TEST_CASE("Nested arena") {
    // Guard limits parallelism/stack size for the duration of the test.
    TestCaseGuard guard;
    TestNestedArena();
}

//! Test with external threads
//! \brief \ref error_guessing
TEST_CASE("External threads") {
    TestNativeThread();
}

//! Stress test with external threads
//! \brief \ref stress
TEST_CASE("Stress test with external threads") {
    TestCleanupMaster();
}

//! Test with an arena observer
//! \brief \ref error_guessing
TEST_CASE("Arena observer") {
    TestObservers();
}
453 #endif /* __TBB_RESUMABLE_TASKS */
454