xref: /oneTBB/test/tbb/test_resumable_tasks.cpp (revision d86ed7fb)
1 /*
2     Copyright (c) 2005-2020 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 //! \file test_resumable_tasks.cpp
18 //! \brief Test for [scheduler.resumable_tasks] specification
19 
20 #include "common/test.h"
21 #include "common/utils.h"
22 
23 #include "tbb/task.h"
24 
25 #if __TBB_RESUMABLE_TASKS
26 
27 #include "tbb/global_control.h"
28 #include "tbb/task_arena.h"
29 #include "tbb/parallel_for.h"
30 #include "tbb/task_scheduler_observer.h"
31 #include "tbb/task_group.h"
32 
33 #include <algorithm>
34 #include <thread>
35 #include <queue>
36 #include <condition_variable>
37 
38 const int N = 10;
39 
40 // External activity used in all tests, which resumes suspended execution point
41 class AsyncActivity {
42 public:
43     AsyncActivity(int num_) : m_numAsyncThreads(num_) {
44         for (int i = 0; i < m_numAsyncThreads ; ++i) {
45             m_asyncThreads.push_back( new std::thread(AsyncActivity::asyncLoop, this) );
46         }
47     }
48     ~AsyncActivity() {
49         {
50             std::lock_guard<std::mutex> lock(m_mutex);
51             for (int i = 0; i < m_numAsyncThreads; ++i) {
52                 m_tagQueue.push(nullptr);
53             }
54             m_condvar.notify_all();
55         }
56         for (int i = 0; i < m_numAsyncThreads; ++i) {
57             m_asyncThreads[i]->join();
58             delete m_asyncThreads[i];
59         }
60         CHECK(m_tagQueue.empty());
61     }
62     void submit(tbb::task::suspend_point ctx) {
63         std::lock_guard<std::mutex> lock(m_mutex);
64         m_tagQueue.push(ctx);
65         m_condvar.notify_one();
66     }
67 
68 private:
69     static void asyncLoop(AsyncActivity* async) {
70         tbb::task::suspend_point tag;
71         for (;;) {
72             {
73                 std::unique_lock<std::mutex> lock(async->m_mutex);
74                 async->m_condvar.wait(lock, [async] {return !async->m_tagQueue.empty(); });
75                 tag = async->m_tagQueue.front();
76                 async->m_tagQueue.pop();
77             }
78             if (!tag) {
79                 break;
80             }
81             tbb::task::resume(tag);
82         };
83     }
84 
85     const int m_numAsyncThreads;
86     std::mutex m_mutex;
87     std::condition_variable m_condvar;
88     std::queue<tbb::task::suspend_point> m_tagQueue;
89     std::vector<std::thread*> m_asyncThreads;
90 };
91 
92 struct SuspendBody {
93     SuspendBody(AsyncActivity& a_) :
94         m_asyncActivity(a_) {}
95     void operator()(tbb::task::suspend_point tag) {
96         m_asyncActivity.submit(tag);
97     }
98 
99 private:
100     AsyncActivity& m_asyncActivity;
101 };
102 
103 class InnermostArenaBody {
104 public:
105     InnermostArenaBody(AsyncActivity& a_) : m_asyncActivity(a_) {}
106 
107     void operator()() {
108         InnermostOuterParFor inner_outer_body(m_asyncActivity);
109         tbb::parallel_for(0, N, inner_outer_body );
110     }
111 
112 private:
113     struct InnermostInnerParFor {
114         InnermostInnerParFor(AsyncActivity& a_) : m_asyncActivity(a_) {}
115         void operator()(int) const {
116             tbb::task::suspend(SuspendBody(m_asyncActivity));
117         }
118         AsyncActivity& m_asyncActivity;
119     };
120     struct InnermostOuterParFor {
121         InnermostOuterParFor(AsyncActivity& a_) : m_asyncActivity(a_) {}
122         void operator()(int) const {
123             tbb::task::suspend(SuspendBody(m_asyncActivity));
124             InnermostInnerParFor inner_inner_body(m_asyncActivity);
125             tbb::parallel_for(0, N, inner_inner_body);
126         }
127         AsyncActivity& m_asyncActivity;
128     };
129     AsyncActivity& m_asyncActivity;
130 };
131 
132 #include "tbb/enumerable_thread_specific.h"
133 
// Outermost arena body: each of the 32 iterations suspends, then enters one of
// three arenas (selected round-robin by i % 3) and runs InnermostArenaBody
// there. An ETS slot is used to verify the original thread is recalled after
// returning from a "not-same" nested arena.
class OutermostArenaBody {
public:
    OutermostArenaBody(AsyncActivity& a_, tbb::task_arena& o_, tbb::task_arena& i_
            , tbb::task_arena& id_, tbb::enumerable_thread_specific<int>& ets) :
        m_asyncActivity(a_), m_outermostArena(o_), m_innermostArena(i_), m_innermostArenaDefault(id_), m_local(ets) {}

    // Entry point executed inside the outermost arena.
    void operator()() {
        tbb::parallel_for(0, 32, *this);
    }

    void operator()(int i) const {
        tbb::task::suspend(SuspendBody(m_asyncActivity));

        // i % 3 == 0 re-enters the same (outermost) arena; the other two pick
        // a distinct nested arena.
        tbb::task_arena& nested_arena = (i % 3 == 0) ?
            m_outermostArena : (i % 3 == 1 ? m_innermostArena : m_innermostArenaDefault);

        if (i % 3 != 0) {
            // We can only guarantee recall correctness for "not-same" nested arenas entry
            m_local.local() = i;
        }
        InnermostArenaBody innermost_arena_body(m_asyncActivity);
        nested_arena.execute(innermost_arena_body);
        if (i % 3 != 0) {
            CHECK_MESSAGE(i == m_local.local(), "Original thread wasn't recalled for innermost nested arena.");
        }
    }

private:
    AsyncActivity& m_asyncActivity;
    tbb::task_arena& m_outermostArena;
    tbb::task_arena& m_innermostArena;
    tbb::task_arena& m_innermostArenaDefault;
    tbb::enumerable_thread_specific<int>& m_local;
};
168 
169 void TestNestedArena() {
170     AsyncActivity asyncActivity(4);
171 
172     tbb::task_arena outermost_arena;
173     tbb::task_arena innermost_arena(2,2);
174     tbb::task_arena innermost_arena_default;
175 
176     outermost_arena.initialize();
177     innermost_arena_default.initialize();
178     innermost_arena.initialize();
179 
180     tbb::enumerable_thread_specific<int> ets;
181 
182     OutermostArenaBody outer_arena_body(asyncActivity, outermost_arena, innermost_arena, innermost_arena_default, ets);
183     outermost_arena.execute(outer_arena_body);
184 }
185 
186 // External activity used in all tests, which resumes suspended execution point
187 class EpochAsyncActivity {
188 public:
189     EpochAsyncActivity(int num_, std::atomic<int>& e_) : m_numAsyncThreads(num_), m_globalEpoch(e_) {
190         for (int i = 0; i < m_numAsyncThreads ; ++i) {
191             m_asyncThreads.push_back( new std::thread(EpochAsyncActivity::asyncLoop, this) );
192         }
193     }
194     ~EpochAsyncActivity() {
195         {
196             std::lock_guard<std::mutex> lock(m_mutex);
197             for (int i = 0; i < m_numAsyncThreads; ++i) {
198                 m_tagQueue.push(nullptr);
199             }
200             m_condvar.notify_all();
201         }
202         for (int i = 0; i < m_numAsyncThreads; ++i) {
203             m_asyncThreads[i]->join();
204             delete m_asyncThreads[i];
205         }
206         CHECK(m_tagQueue.empty());
207     }
208     void submit(tbb::task::suspend_point ctx) {
209         std::lock_guard<std::mutex> lock(m_mutex);
210         m_tagQueue.push(ctx);
211         m_condvar.notify_one();
212     }
213 
214 private:
215     static void asyncLoop(EpochAsyncActivity* async) {
216         tbb::task::suspend_point tag;
217         for (;;) {
218             {
219                 std::unique_lock<std::mutex> lock(async->m_mutex);
220                 async->m_condvar.wait(lock, [async] {return !async->m_tagQueue.empty(); });
221                 tag = async->m_tagQueue.front();
222                 async->m_tagQueue.pop();
223             }
224             if (!tag) {
225                 break;
226             }
227             // Track the global epoch
228             async->m_globalEpoch++;
229             tbb::task::resume(tag);
230         };
231     }
232 
233     const int m_numAsyncThreads;
234     std::atomic<int>& m_globalEpoch;
235     std::mutex m_mutex;
236     std::condition_variable m_condvar;
237     std::queue<tbb::task::suspend_point> m_tagQueue;
238     std::vector<std::thread*> m_asyncThreads;
239 };
240 
241 struct EpochSuspendBody {
242     EpochSuspendBody(EpochAsyncActivity& a_, std::atomic<int>& e_, int& le_) :
243         m_asyncActivity(a_), m_globalEpoch(e_), m_localEpoch(le_) {}
244 
245     void operator()(tbb::task::suspend_point ctx) {
246         m_localEpoch = m_globalEpoch;
247         m_asyncActivity.submit(ctx);
248     }
249 
250 private:
251     EpochAsyncActivity& m_asyncActivity;
252     std::atomic<int>& m_globalEpoch;
253     int& m_localEpoch;
254 };
255 
256 // Simple test for basic resumable tasks functionality
257 void TestSuspendResume() {
258     std::atomic<int> global_epoch; global_epoch = 0;
259     EpochAsyncActivity async(4, global_epoch);
260 
261     tbb::enumerable_thread_specific<int, tbb::cache_aligned_allocator<int>, tbb::ets_suspend_aware> ets_fiber;
262     std::atomic<int> inner_par_iters, outer_par_iters;
263     inner_par_iters = outer_par_iters = 0;
264 
265     tbb::parallel_for(0, N, [&](int) {
266         for (int i = 0; i < 100; ++i) {
267             ets_fiber.local() = i;
268 
269             int local_epoch;
270             tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch));
271             CHECK(local_epoch < global_epoch);
272             CHECK(ets_fiber.local() == i);
273 
274             tbb::parallel_for(0, N, [&](int) {
275                 int local_epoch2;
276                 tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch2));
277                 CHECK(local_epoch2 < global_epoch);
278                 ++inner_par_iters;
279             });
280 
281             ets_fiber.local() = i;
282             tbb::task::suspend(EpochSuspendBody(async, global_epoch, local_epoch));
283             CHECK(local_epoch < global_epoch);
284             CHECK(ets_fiber.local() == i);
285         }
286         ++outer_par_iters;
287     });
288     CHECK(outer_par_iters == N);
289     CHECK(inner_par_iters == N*N*100);
290 }
291 
292 // During cleanup master's local task pool may
293 // e.g. contain proxies of affinitized tasks, but can be recalled
294 void TestCleanupMaster() {
295     if (tbb::this_task_arena::max_concurrency() == 1) {
296         // The test requires at least 2 threads
297         return;
298     }
299     AsyncActivity asyncActivity(4);
300     tbb::task_group tg;
301     std::atomic<int> iter_spawned;
302     std::atomic<int> iter_executed;
303 
304     for (int i = 0; i < 100; i++) {
305         iter_spawned = 0;
306         iter_executed = 0;
307 
308         utils::NativeParallelFor(N, [&asyncActivity, &tg, &iter_spawned, &iter_executed](int j) {
309             for (int k = 0; k < j*10 + 1; ++k) {
310                 tg.run([&asyncActivity, j, &iter_executed] {
311                     for (volatile int l = 0; l < j*10; ++l) {}
312                     tbb::task::suspend(SuspendBody(asyncActivity));
313                     iter_executed++;
314                 });
315                 iter_spawned++;
316             }
317         });
318         CHECK(iter_spawned == 460);
319         tg.wait();
320         CHECK(iter_executed == 460);
321     }
322 }
323 class ParForSuspendBody {
324     AsyncActivity& asyncActivity;
325     int m_numIters;
326 public:
327     ParForSuspendBody(AsyncActivity& a_, int iters) : asyncActivity(a_), m_numIters(iters) {}
328     void operator()(int) const {
329         for (volatile int i = 0; i < m_numIters; ++i) {}
330         tbb::task::suspend(SuspendBody(asyncActivity));
331     }
332 };
333 
334 void TestNativeThread() {
335     AsyncActivity asyncActivity(4);
336 
337     tbb::task_arena arena;
338     tbb::task_group tg;
339     std::atomic<int> iter{};
340     utils::NativeParallelFor(arena.max_concurrency() / 2, [&arena, &tg, &asyncActivity, &iter](int) {
341         for (int i = 0; i < 10; i++) {
342             arena.execute([&tg, &asyncActivity, &iter]() {
343                 tg.run([&asyncActivity]() {
344                     tbb::task::suspend(SuspendBody(asyncActivity));
345                 });
346                 iter++;
347             });
348         }
349     });
350 
351     CHECK(iter == (arena.max_concurrency() / 2 * 10));
352     arena.execute([&tg](){
353         tg.wait();
354     });
355 }
356 
// Observer that verifies entry/exit pairing per thread: even when resumable
// tasks switch execution contexts, a thread must not receive a second
// on_scheduler_entry without an intervening on_scheduler_exit.
class ObserverTracker : public tbb::task_scheduler_observer {
    // Per-thread flag: true between on_scheduler_entry and on_scheduler_exit.
    static thread_local bool is_in_arena;
public:
    // Total number of entry callbacks observed across all threads.
    std::atomic<int> counter;

    ObserverTracker(tbb::task_arena& a) : tbb::task_scheduler_observer(a) {
        counter = 0;
        observe(true);
    }
    void on_scheduler_entry(bool) override {
        bool& l = is_in_arena;
        CHECK_MESSAGE(l == false, "The thread must call on_scheduler_entry only one time.");
        l = true;
        ++counter;
    }
    void on_scheduler_exit(bool) override {
        bool& l = is_in_arena;
        CHECK_MESSAGE(l == true, "The thread must call on_scheduler_entry before calling on_scheduler_exit.");
        l = false;
    }
};
378 
379 thread_local bool ObserverTracker::is_in_arena;
380 
381 void TestObservers() {
382     tbb::task_arena arena;
383     ObserverTracker tracker(arena);
384     do {
385         arena.execute([] {
386             tbb::parallel_for(0, 10, [](int) {
387                 tbb::task::suspend([](tbb::task::suspend_point tag) {
388                     tbb::task::resume(tag);
389                 });
390             }, tbb::simple_partitioner());
391         });
392     } while (tracker.counter < 100);
393     tracker.observe(false);
394 }
395 
// RAII guard for a test case: sets a thread limit of at least 16 and a small
// (128 KB) thread stack size for the test's duration, and uses a thread_local
// flag to check that guarded test cases do not nest on the same thread.
class TestCaseGuard {
    // True while a guarded test case is running on this thread.
    static thread_local bool m_local;
    tbb::global_control m_threadLimit;
    tbb::global_control m_stackLimit;
public:
    TestCaseGuard()
        : m_threadLimit(tbb::global_control::max_allowed_parallelism, std::max(tbb::this_task_arena::max_concurrency(), 16))
        , m_stackLimit(tbb::global_control::thread_stack_size, 128*1024)
    {
        CHECK(m_local == false);
        m_local = true;
    }
    ~TestCaseGuard() {
        CHECK(m_local == true);
        m_local = false;
    }
};
413 
414 thread_local bool TestCaseGuard::m_local = false;
415 
//! Nested test for suspend and resume
//! \brief \ref error_guessing
TEST_CASE("Nested test for suspend and resume") {
    TestCaseGuard guard; // raises thread limit / shrinks stacks for this test
    TestSuspendResume();
}
422 
//! Nested arena complex test
//! \brief \ref error_guessing
TEST_CASE("Nested arena") {
    TestCaseGuard guard; // raises thread limit / shrinks stacks for this test
    TestNestedArena();
}
429 
//! Test with external threads
//! \brief \ref error_guessing
TEST_CASE("External threads") {
    // Runs without TestCaseGuard: uses the default thread and stack limits.
    TestNativeThread();
}
435 
//! Stress test with external threads
//! \brief \ref stress
TEST_CASE("Stress test with external threads") {
    // Runs without TestCaseGuard: uses the default thread and stack limits.
    TestCleanupMaster();
}
441 
//! Test with an arena observer
//! \brief \ref error_guessing
TEST_CASE("Arena observer") {
    // Runs without TestCaseGuard: uses the default thread and stack limits.
    TestObservers();
}
447 #endif /* __TBB_RESUMABLE_TASKS */
448