xref: /oneTBB/test/tbb/test_continue_node.cpp (revision 29009a8e)
1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
#include "common/config.h"

#include "tbb/flow_graph.h"

#include "common/test.h"
#include "common/utils.h"
#include "common/graph_utils.h"
#include "common/test_follows_and_precedes_api.h"
#include "common/concepts_common.h"

#include <atomic>
#include <cstddef>
#include <memory>
#include <thread>
#include <vector>
26 
27 
28 //! \file test_continue_node.cpp
29 //! \brief Test for [flow_graph.continue_node] specification
30 
31 
// Number of fake predecessors registered on every continue_node under test, and also
// the number of continue_msg puts each thread issues in parallel_puts.
#define N 1000
// Largest number of successors connected to the node under test in a single round.
#define MAX_NODES 4
35 
36 // A class to use as a fake predecessor of continue_node
37 struct fake_continue_sender : public tbb::flow::sender<tbb::flow::continue_msg>
38 {
39     typedef tbb::flow::sender<tbb::flow::continue_msg>::successor_type successor_type;
40     // Define implementations of virtual methods that are abstract in the base class
41     bool register_successor( successor_type& ) override { return false; }
42     bool remove_successor( successor_type& )   override { return false; }
43 };
44 
45 template< typename InputType >
46 struct parallel_puts {
47 
48     tbb::flow::receiver< InputType > * const my_exe_node;
49 
50     parallel_puts( tbb::flow::receiver< InputType > &exe_node ) : my_exe_node(&exe_node) {}
51     parallel_puts& operator=(const parallel_puts&) = delete;
52 
53     void operator()( int ) const  {
54         for ( int i = 0; i < N; ++i ) {
55             // the nodes will accept all puts
56             CHECK_MESSAGE( my_exe_node->try_put( InputType() ) == true, "" );
57         }
58     }
59 
60 };
61 
62 template< typename OutputType >
63 void run_continue_nodes( int p, tbb::flow::graph& g, tbb::flow::continue_node< OutputType >& n ) {
64     fake_continue_sender fake_sender;
65     for (size_t i = 0; i < N; ++i) {
66         tbb::detail::d1::register_predecessor(n, fake_sender);
67     }
68 
69     for (size_t num_receivers = 1; num_receivers <= MAX_NODES; ++num_receivers ) {
70         std::vector< std::shared_ptr<harness_counting_receiver<OutputType>> > receivers;
71         for (size_t i = 0; i < num_receivers; ++i) {
72             receivers.push_back( std::make_shared<harness_counting_receiver<OutputType>>(g) );
73         }
74         harness_graph_executor<tbb::flow::continue_msg, OutputType>::execute_count = 0;
75 
76         for (size_t r = 0; r < num_receivers; ++r ) {
77             tbb::flow::make_edge( n, *receivers[r] );
78         }
79 
80         utils::NativeParallelFor( p, parallel_puts<tbb::flow::continue_msg>(n) );
81         g.wait_for_all();
82 
83         // 2) the nodes will receive puts from multiple predecessors simultaneously,
84         size_t ec = harness_graph_executor<tbb::flow::continue_msg, OutputType>::execute_count;
85         CHECK_MESSAGE( (int)ec == p, "" );
86         for (size_t r = 0; r < num_receivers; ++r ) {
87             size_t c = receivers[r]->my_count;
88             // 3) the nodes will send to multiple successors.
89             CHECK_MESSAGE( (int)c == p, "" );
90         }
91 
92         for (size_t r = 0; r < num_receivers; ++r ) {
93             tbb::flow::remove_edge( n, *receivers[r] );
94         }
95     }
96 }
97 
98 template< typename OutputType, typename Body >
99 void continue_nodes( Body body ) {
100     for (int p = 1; p < 2*4/*MaxThread*/; ++p) {
101         tbb::flow::graph g;
102         tbb::flow::continue_node< OutputType > exe_node( g, body );
103         run_continue_nodes( p, g, exe_node);
104         exe_node.try_put(tbb::flow::continue_msg());
105         tbb::flow::continue_node< OutputType > exe_node_copy( exe_node );
106         run_continue_nodes( p, g, exe_node_copy);
107     }
108 }
109 
// Initial value given both to the body's local counter and to global_execute_count
// in continue_nodes_with_copy().
constexpr size_t Offset = 123;
// Total number of inc_functor invocations, across all copies of the functor.
std::atomic<size_t> global_execute_count{0};
112 
113 template< typename OutputType >
114 struct inc_functor {
115 
116     std::atomic<size_t> local_execute_count;
117     inc_functor( ) { local_execute_count = 0; }
118     inc_functor( const inc_functor &f ) { local_execute_count = size_t(f.local_execute_count); }
119     void operator=(const inc_functor &f) { local_execute_count = size_t(f.local_execute_count); }
120 
121     OutputType operator()( tbb::flow::continue_msg ) {
122        ++global_execute_count;
123        ++local_execute_count;
124        return OutputType();
125     }
126 
127 };
128 
129 template< typename OutputType >
130 void continue_nodes_with_copy( ) {
131 
132     for (int p = 1; p < 2*4/*MaxThread*/; ++p) {
133         tbb::flow::graph g;
134         inc_functor<OutputType> cf;
135         cf.local_execute_count = Offset;
136         global_execute_count = Offset;
137 
138         tbb::flow::continue_node< OutputType > exe_node( g, cf );
139         fake_continue_sender fake_sender;
140         for (size_t i = 0; i < N; ++i) {
141             tbb::detail::d1::register_predecessor(exe_node, fake_sender);
142         }
143 
144         for (size_t num_receivers = 1; num_receivers <= MAX_NODES; ++num_receivers ) {
145             std::vector< std::shared_ptr<harness_counting_receiver<OutputType>> > receivers;
146             for (size_t i = 0; i < num_receivers; ++i) {
147                 receivers.push_back( std::make_shared<harness_counting_receiver<OutputType>>(g) );
148             }
149 
150             for (size_t r = 0; r < num_receivers; ++r ) {
151                 tbb::flow::make_edge( exe_node, *receivers[r] );
152             }
153 
154             utils::NativeParallelFor( p, parallel_puts<tbb::flow::continue_msg>(exe_node) );
155             g.wait_for_all();
156 
157             // 2) the nodes will receive puts from multiple predecessors simultaneously,
158             for (size_t r = 0; r < num_receivers; ++r ) {
159                 size_t c = receivers[r]->my_count;
160                 // 3) the nodes will send to multiple successors.
161                 CHECK_MESSAGE( (int)c == p, "" );
162             }
163             for (size_t r = 0; r < num_receivers; ++r ) {
164                 tbb::flow::remove_edge( exe_node, *receivers[r] );
165             }
166         }
167 
168         // validate that the local body matches the global execute_count and both are correct
169         inc_functor<OutputType> body_copy = tbb::flow::copy_body< inc_functor<OutputType> >( exe_node );
170         const size_t expected_count = p*MAX_NODES + Offset;
171         size_t global_count = global_execute_count;
172         size_t inc_count = body_copy.local_execute_count;
173         CHECK_MESSAGE( global_count == expected_count, "" );
174         CHECK_MESSAGE( global_count == inc_count, "" );
175         g.reset(tbb::flow::rf_reset_bodies);
176         body_copy = tbb::flow::copy_body< inc_functor<OutputType> >( exe_node );
177         inc_count = body_copy.local_execute_count;
178         CHECK_MESSAGE( ( Offset == inc_count), "reset(rf_reset_bodies) did not reset functor" );
179 
180     }
181 }
182 
183 template< typename OutputType >
184 void run_continue_nodes() {
185     harness_graph_executor< tbb::flow::continue_msg, OutputType>::max_executors = 0;
186     continue_nodes<OutputType>( []( tbb::flow::continue_msg i ) -> OutputType { return harness_graph_executor<tbb::flow::continue_msg, OutputType>::func(i); } );
187     continue_nodes<OutputType>( &harness_graph_executor<tbb::flow::continue_msg, OutputType>::func );
188     continue_nodes<OutputType>( typename harness_graph_executor<tbb::flow::continue_msg, OutputType>::functor() );
189     continue_nodes_with_copy<OutputType>();
190 }
191 
192 //! Tests limited concurrency cases for nodes that accept data messages
193 void test_concurrency(int num_threads) {
194     tbb::task_arena arena(num_threads);
195     arena.execute(
196         [&] {
197             run_continue_nodes<tbb::flow::continue_msg>();
198             run_continue_nodes<int>();
199             run_continue_nodes<utils::NoAssign>();
200         }
201     );
202 }
203 /*
204  * Connection of two graphs is not currently supported, but works to some limited extent.
205  * This test is included to check for backward compatibility. It checks that a continue_node
206  * with predecessors in two different graphs receives the required
207  * number of continue messages before it executes.
208  */
209 using namespace tbb::flow;
210 
211 struct add_to_counter {
212     int* counter;
213     add_to_counter(int& var):counter(&var){}
214     void operator()(continue_msg){*counter+=1;}
215 };
216 
217 void test_two_graphs(){
218     int count=0;
219 
220     //graph g with broadcast_node and continue_node
221     graph g;
222     broadcast_node<continue_msg> start_g(g);
223     continue_node<continue_msg> first_g(g, add_to_counter(count));
224 
225     //graph h with broadcast_node
226     graph h;
227     broadcast_node<continue_msg> start_h(h);
228 
229     //making two edges to first_g from the two graphs
230     make_edge(start_g,first_g);
231     make_edge(start_h, first_g);
232 
233     //two try_puts from the two graphs
234     start_g.try_put(continue_msg());
235     start_h.try_put(continue_msg());
236     g.wait_for_all();
237     CHECK_MESSAGE( (count==1), "Not all continue messages received");
238 
239     //two try_puts from the graph that doesn't contain the node
240     count=0;
241     start_h.try_put(continue_msg());
242     start_h.try_put(continue_msg());
243     g.wait_for_all();
244     CHECK_MESSAGE( (count==1), "Not all continue messages received -1");
245 
246     //only one try_put
247     count=0;
248     start_g.try_put(continue_msg());
249     g.wait_for_all();
250     CHECK_MESSAGE( (count==0), "Node executed without waiting for all predecessors");
251 }
252 
253 struct lightweight_policy_body {
254     const std::thread::id my_thread_id;
255     std::atomic<size_t>& my_count;
256 
257     lightweight_policy_body( std::atomic<size_t>& count )
258         : my_thread_id(std::this_thread::get_id()), my_count(count)
259     {
260         my_count = 0;
261     }
262     lightweight_policy_body& operator=(const lightweight_policy_body&) = delete;
263     void operator()(tbb::flow::continue_msg) {
264         ++my_count;
265         std::thread::id body_thread_id = std::this_thread::get_id();
266         CHECK_MESSAGE( (body_thread_id == my_thread_id), "Body executed as not lightweight");
267     }
268 };
269 
270 void test_lightweight_policy() {
271     tbb::flow::graph g;
272     std::atomic<size_t> count1;
273     std::atomic<size_t> count2;
274     tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>
275         node1(g, lightweight_policy_body(count1));
276     tbb::flow::continue_node<tbb::flow::continue_msg, tbb::flow::lightweight>
277         node2(g, lightweight_policy_body(count2));
278 
279     tbb::flow::make_edge(node1, node2);
280     const size_t n = 10;
281     for(size_t i = 0; i < n; ++i) {
282         node1.try_put(tbb::flow::continue_msg());
283     }
284     g.wait_for_all();
285 
286     lightweight_policy_body body1 = tbb::flow::copy_body<lightweight_policy_body>(node1);
287     lightweight_policy_body body2 = tbb::flow::copy_body<lightweight_policy_body>(node2);
288     CHECK_MESSAGE( (body1.my_count == n), "Body of the first node needs to be executed N times");
289     CHECK_MESSAGE( (body2.my_count == n), "Body of the second node needs to be executed N times");
290 }
291 
292 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
293 #include <array>
294 #include <vector>
295 void test_follows_and_precedes_api() {
296     using msg_t = tbb::flow::continue_msg;
297 
298     std::array<msg_t, 3> messages_for_follows = { { msg_t(), msg_t(), msg_t() } };
299     std::vector<msg_t> messages_for_precedes  = { msg_t() };
300 
301     auto pass_through = [](const msg_t& msg) { return msg; };
302 
303     follows_and_precedes_testing::test_follows
304         <msg_t, tbb::flow::continue_node<msg_t>>
305         (messages_for_follows, pass_through, node_priority_t(0));
306 
307     follows_and_precedes_testing::test_precedes
308         <msg_t, tbb::flow::continue_node<msg_t>>
309         (messages_for_precedes, /* number_of_predecessors = */0, pass_through, node_priority_t(1));
310 }
311 #endif // __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
312 
313 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
314 
315 template <typename ExpectedType, typename Body>
316 void test_deduction_guides_common(Body body) {
317     using namespace tbb::flow;
318     graph g;
319 
320     continue_node c1(g, body);
321     static_assert(std::is_same_v<decltype(c1), continue_node<ExpectedType>>);
322 
323     continue_node c2(g, body, lightweight());
324     static_assert(std::is_same_v<decltype(c2), continue_node<ExpectedType, lightweight>>);
325 
326     continue_node c3(g, 5, body);
327     static_assert(std::is_same_v<decltype(c3), continue_node<ExpectedType>>);
328 
329     continue_node c4(g, 5, body, lightweight());
330     static_assert(std::is_same_v<decltype(c4), continue_node<ExpectedType, lightweight>>);
331 
332     continue_node c5(g, body, node_priority_t(5));
333     static_assert(std::is_same_v<decltype(c5), continue_node<ExpectedType>>);
334 
335     continue_node c6(g, body, lightweight(), node_priority_t(5));
336     static_assert(std::is_same_v<decltype(c6), continue_node<ExpectedType, lightweight>>);
337 
338     continue_node c7(g, 5, body, node_priority_t(5));
339     static_assert(std::is_same_v<decltype(c7), continue_node<ExpectedType>>);
340 
341     continue_node c8(g, 5, body, lightweight(), node_priority_t(5));
342     static_assert(std::is_same_v<decltype(c8), continue_node<ExpectedType, lightweight>>);
343 
344 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
345     broadcast_node<continue_msg> b(g);
346 
347     continue_node c9(follows(b), body);
348     static_assert(std::is_same_v<decltype(c9), continue_node<ExpectedType>>);
349 
350     continue_node c10(follows(b), body, lightweight());
351     static_assert(std::is_same_v<decltype(c10), continue_node<ExpectedType, lightweight>>);
352 
353     continue_node c11(follows(b), 5, body);
354     static_assert(std::is_same_v<decltype(c11), continue_node<ExpectedType>>);
355 
356     continue_node c12(follows(b), 5, body, lightweight());
357     static_assert(std::is_same_v<decltype(c12), continue_node<ExpectedType, lightweight>>);
358 
359     continue_node c13(follows(b), body, node_priority_t(5));
360     static_assert(std::is_same_v<decltype(c13), continue_node<ExpectedType>>);
361 
362     continue_node c14(follows(b), body, lightweight(), node_priority_t(5));
363     static_assert(std::is_same_v<decltype(c14), continue_node<ExpectedType, lightweight>>);
364 
365     continue_node c15(follows(b), 5, body, node_priority_t(5));
366     static_assert(std::is_same_v<decltype(c15), continue_node<ExpectedType>>);
367 
368     continue_node c16(follows(b), 5, body, lightweight(), node_priority_t(5));
369     static_assert(std::is_same_v<decltype(c16), continue_node<ExpectedType, lightweight>>);
370 #endif // __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
371 
372     continue_node c17(c1);
373     static_assert(std::is_same_v<decltype(c17), continue_node<ExpectedType>>);
374 }
375 
376 int continue_body_f(const tbb::flow::continue_msg&) { return 1; }
377 void continue_void_body_f(const tbb::flow::continue_msg&) {}
378 
379 void test_deduction_guides() {
380     using tbb::flow::continue_msg;
381     test_deduction_guides_common<int>([](const continue_msg&)->int { return 1; } );
382     test_deduction_guides_common<continue_msg>([](const continue_msg&) {});
383 
384     test_deduction_guides_common<int>([](const continue_msg&) mutable ->int { return 1; });
385     test_deduction_guides_common<continue_msg>([](const continue_msg&) mutable {});
386 
387     test_deduction_guides_common<int>(continue_body_f);
388     test_deduction_guides_common<continue_msg>(continue_void_body_f);
389 }
390 
391 #endif // __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
392 
// TODO: use pass_through from test_function_node instead
// Body that returns its argument unchanged. The call operator is const because the
// functor holds no state, so it can also be invoked through a const object.
template<typename T>
struct passing_body {
    T operator()(const T& val) const {
        return val;
    }
};
400 
/*
    The test covers the case when a node with a non-default mutex type is a predecessor of a continue_node.
    There used to be a bug where make_edge(node, continue_node)
    did not update the continue_node's predecessor threshold,
    because the successor_cache specialization for a continue_node successor was not chosen.
*/
407 void test_successor_cache_specialization() {
408     using namespace tbb::flow;
409 
410     graph g;
411 
412     broadcast_node<continue_msg> node_with_default_mutex_type(g);
413     buffer_node<continue_msg> node_with_non_default_mutex_type(g);
414 
415     continue_node<continue_msg> node(g, passing_body<continue_msg>());
416 
417     make_edge(node_with_default_mutex_type, node);
418     make_edge(node_with_non_default_mutex_type, node);
419 
420     buffer_node<continue_msg> buf(g);
421 
422     make_edge(node, buf);
423 
424     node_with_default_mutex_type.try_put(continue_msg());
425     node_with_non_default_mutex_type.try_put(continue_msg());
426 
427     g.wait_for_all();
428 
429     continue_msg storage;
430     CHECK_MESSAGE((buf.try_get(storage) && !buf.try_get(storage)),
431                   "Wrong number of messages is passed via continue_node");
432 }
433 
434 //! Test concurrent continue_node for correctness
435 //! \brief \ref error_guessing
436 TEST_CASE("Concurrency testing") {
437     for( unsigned p=utils::MinThread; p<=utils::MaxThread; ++p ) {
438         test_concurrency(p);
439     }
440 }
441 
442 //! Test concurrent continue_node in separate graphs
443 //! \brief \ref error_guessing
444 TEST_CASE("Two graphs") { test_two_graphs(); }
445 
446 //! Test basic behaviour with lightweight body
447 //! \brief \ref requirement \ref error_guessing
448 TEST_CASE( "Lightweight policy" ) { test_lightweight_policy(); }
449 
450 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
//! Test deprecated follows and precedes API
452 //! \brief \ref error_guessing
453 TEST_CASE( "Support for follows and precedes API" ) { test_follows_and_precedes_api(); }
454 #endif
455 
456 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
457 //! Test deduction guides
//! \brief \ref requirement
459 TEST_CASE( "Deduction guides" ) { test_deduction_guides(); }
460 #endif
461 
462 //! Test for successor cache specialization
463 //! \brief \ref regression
464 TEST_CASE( "Regression for successor cache specialization" ) {
465     test_successor_cache_specialization();
466 }
467 
468 #if __TBB_CPP20_CONCEPTS_PRESENT
469 //! \brief \ref error_guessing
470 TEST_CASE("constraints for continue_node input") {
471     static_assert(utils::well_formed_instantiation<tbb::flow::continue_node, test_concepts::Copyable>);
472     static_assert(!utils::well_formed_instantiation<tbb::flow::continue_node, test_concepts::NonCopyable>);
473 }
474 
475 template <typename Input, typename Body>
476 concept can_call_continue_node_ctor = requires( tbb::flow::graph& graph, Body body,
477                                                 tbb::flow::buffer_node<int>& f, std::size_t num,
478                                                 tbb::flow::node_priority_t priority  ) {
479     tbb::flow::continue_node<Input>(graph, body);
480     tbb::flow::continue_node<Input>(graph, body, priority);
481     tbb::flow::continue_node<Input>(graph, num, body);
482     tbb::flow::continue_node<Input>(graph, num, body, priority);
483 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
484     tbb::flow::continue_node<Input>(tbb::flow::follows(f), body);
485     tbb::flow::continue_node<Input>(tbb::flow::follows(f), body, priority);
486     tbb::flow::continue_node<Input>(tbb::flow::follows(f), num, body);
487     tbb::flow::continue_node<Input>(tbb::flow::follows(f), num, body, priority);
488 #endif // __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
489 };
490 
491 //! \brief \ref error_guessing
492 TEST_CASE("constraints for continue_node body") {
493     using output_type = int;
494     using namespace test_concepts::continue_node_body;
495 
496     static_assert(can_call_continue_node_ctor<output_type, Correct<output_type>>);
497     static_assert(!can_call_continue_node_ctor<output_type, NonCopyable<output_type>>);
498     static_assert(!can_call_continue_node_ctor<output_type, NonDestructible<output_type>>);
499     static_assert(!can_call_continue_node_ctor<output_type, NoOperatorRoundBrackets<output_type>>);
500     static_assert(!can_call_continue_node_ctor<output_type, WrongInputOperatorRoundBrackets<output_type>>);
501     static_assert(!can_call_continue_node_ctor<output_type, WrongReturnOperatorRoundBrackets<output_type>>);
502 }
503 #endif // __TBB_CPP20_CONCEPTS_PRESENT
504