xref: /oneTBB/test/tbb/test_input_node.cpp (revision b15aabb3)
151c0b2f7Stbbdev /*
2*b15aabb3Stbbdev     Copyright (c) 2005-2021 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
1751c0b2f7Stbbdev // have to expose the reset_node method to be able to reset a function_body
1851c0b2f7Stbbdev 
1951c0b2f7Stbbdev #include "common/config.h"
2051c0b2f7Stbbdev 
2151c0b2f7Stbbdev #include "tbb/flow_graph.h"
2251c0b2f7Stbbdev 
2351c0b2f7Stbbdev #include "common/test.h"
2451c0b2f7Stbbdev #include "common/utils.h"
2551c0b2f7Stbbdev #include "common/utils_assert.h"
2651c0b2f7Stbbdev 
2751c0b2f7Stbbdev 
2851c0b2f7Stbbdev //! \file test_input_node.cpp
2951c0b2f7Stbbdev //! \brief Test for [flow_graph.input_node] specification
3051c0b2f7Stbbdev 
3151c0b2f7Stbbdev 
3251c0b2f7Stbbdev using tbb::detail::d1::graph_task;
3351c0b2f7Stbbdev using tbb::detail::d1::SUCCESSFULLY_ENQUEUED;
3451c0b2f7Stbbdev 
3551c0b2f7Stbbdev const int N = 1000;
3651c0b2f7Stbbdev 
3751c0b2f7Stbbdev template< typename T >
3851c0b2f7Stbbdev class test_push_receiver : public tbb::flow::receiver<T>, utils::NoAssign {
3951c0b2f7Stbbdev 
4051c0b2f7Stbbdev     std::atomic<int> my_counters[N];
4151c0b2f7Stbbdev     tbb::flow::graph& my_graph;
4251c0b2f7Stbbdev 
4351c0b2f7Stbbdev public:
4451c0b2f7Stbbdev 
4551c0b2f7Stbbdev     test_push_receiver(tbb::flow::graph& g) : my_graph(g) {
4651c0b2f7Stbbdev         for (int i = 0; i < N; ++i )
4751c0b2f7Stbbdev             my_counters[i] = 0;
4851c0b2f7Stbbdev     }
4951c0b2f7Stbbdev 
5051c0b2f7Stbbdev     int get_count( int i ) {
5151c0b2f7Stbbdev         int v = my_counters[i];
5251c0b2f7Stbbdev         return v;
5351c0b2f7Stbbdev     }
5451c0b2f7Stbbdev 
5551c0b2f7Stbbdev     typedef typename tbb::flow::receiver<T>::predecessor_type predecessor_type;
5651c0b2f7Stbbdev 
5751c0b2f7Stbbdev     graph_task* try_put_task( const T &v ) override {
5851c0b2f7Stbbdev         int i = (int)v;
5951c0b2f7Stbbdev         ++my_counters[i];
6051c0b2f7Stbbdev         return const_cast<graph_task*>(SUCCESSFULLY_ENQUEUED);
6151c0b2f7Stbbdev     }
6251c0b2f7Stbbdev 
6351c0b2f7Stbbdev     tbb::flow::graph& graph_reference() const override {
6451c0b2f7Stbbdev         return my_graph;
6551c0b2f7Stbbdev     }
6651c0b2f7Stbbdev };
6751c0b2f7Stbbdev 
6851c0b2f7Stbbdev template< typename T >
6951c0b2f7Stbbdev class my_input_body {
7051c0b2f7Stbbdev 
7151c0b2f7Stbbdev     unsigned my_count;
7251c0b2f7Stbbdev     int *ninvocations;
7351c0b2f7Stbbdev 
7451c0b2f7Stbbdev public:
7551c0b2f7Stbbdev 
7651c0b2f7Stbbdev     my_input_body() : ninvocations(NULL) { my_count = 0; }
7751c0b2f7Stbbdev     my_input_body(int &_inv) : ninvocations(&_inv)  { my_count = 0; }
7851c0b2f7Stbbdev 
7951c0b2f7Stbbdev     T operator()( tbb::flow_control& fc ) {
8051c0b2f7Stbbdev         T v = (T)my_count++;
8151c0b2f7Stbbdev         if(ninvocations) ++(*ninvocations);
8251c0b2f7Stbbdev         if ( (int)v < N ){
8351c0b2f7Stbbdev             return v;
8451c0b2f7Stbbdev         }else{
8551c0b2f7Stbbdev             fc.stop();
8651c0b2f7Stbbdev             return T();
8751c0b2f7Stbbdev         }
8851c0b2f7Stbbdev     }
8951c0b2f7Stbbdev 
9051c0b2f7Stbbdev };
9151c0b2f7Stbbdev 
9251c0b2f7Stbbdev template< typename T >
9351c0b2f7Stbbdev class function_body {
9451c0b2f7Stbbdev 
9551c0b2f7Stbbdev     std::atomic<int> *my_counters;
9651c0b2f7Stbbdev 
9751c0b2f7Stbbdev public:
9851c0b2f7Stbbdev 
9951c0b2f7Stbbdev     function_body( std::atomic<int> *counters ) : my_counters(counters) {
10051c0b2f7Stbbdev         for (int i = 0; i < N; ++i )
10151c0b2f7Stbbdev             my_counters[i] = 0;
10251c0b2f7Stbbdev     }
10351c0b2f7Stbbdev 
10451c0b2f7Stbbdev     bool operator()( T v ) {
10551c0b2f7Stbbdev         ++my_counters[(int)v];
10651c0b2f7Stbbdev         return true;
10751c0b2f7Stbbdev     }
10851c0b2f7Stbbdev 
10951c0b2f7Stbbdev };
11051c0b2f7Stbbdev 
11151c0b2f7Stbbdev template< typename T >
11251c0b2f7Stbbdev void test_single_dest() {
11351c0b2f7Stbbdev     // push only
11451c0b2f7Stbbdev     tbb::flow::graph g;
11551c0b2f7Stbbdev     tbb::flow::input_node<T> src(g, my_input_body<T>() );
11651c0b2f7Stbbdev     test_push_receiver<T> dest(g);
11751c0b2f7Stbbdev     tbb::flow::make_edge( src, dest );
11851c0b2f7Stbbdev     src.activate();
11951c0b2f7Stbbdev     g.wait_for_all();
12051c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
12151c0b2f7Stbbdev         CHECK_MESSAGE( dest.get_count(i) == 1, "" );
12251c0b2f7Stbbdev     }
12351c0b2f7Stbbdev 
12451c0b2f7Stbbdev     // push only
12551c0b2f7Stbbdev     std::atomic<int> counters3[N];
12651c0b2f7Stbbdev     tbb::flow::input_node<T> src3(g, my_input_body<T>() );
12751c0b2f7Stbbdev     src3.activate();
12851c0b2f7Stbbdev 
12951c0b2f7Stbbdev     function_body<T> b3( counters3 );
13051c0b2f7Stbbdev     tbb::flow::function_node<T,bool> dest3(g, tbb::flow::unlimited, b3 );
13151c0b2f7Stbbdev     tbb::flow::make_edge( src3, dest3 );
13251c0b2f7Stbbdev     g.wait_for_all();
13351c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
13451c0b2f7Stbbdev         int v = counters3[i];
13551c0b2f7Stbbdev         CHECK_MESSAGE( v == 1, "" );
13651c0b2f7Stbbdev     }
13751c0b2f7Stbbdev 
13851c0b2f7Stbbdev     // push & pull
13951c0b2f7Stbbdev     tbb::flow::input_node<T> src2(g, my_input_body<T>() );
14051c0b2f7Stbbdev     src2.activate();
14151c0b2f7Stbbdev     std::atomic<int> counters2[N];
14251c0b2f7Stbbdev 
14351c0b2f7Stbbdev     function_body<T> b2( counters2 );
14451c0b2f7Stbbdev     tbb::flow::function_node<T,bool,tbb::flow::rejecting> dest2(g, tbb::flow::serial, b2 );
14551c0b2f7Stbbdev     tbb::flow::make_edge( src2, dest2 );
14651c0b2f7Stbbdev     g.wait_for_all();
14751c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
14851c0b2f7Stbbdev         int v = counters2[i];
14951c0b2f7Stbbdev         CHECK_MESSAGE( v == 1, "" );
15051c0b2f7Stbbdev     }
15151c0b2f7Stbbdev 
15251c0b2f7Stbbdev     // test copy constructor
15351c0b2f7Stbbdev     tbb::flow::input_node<T> src_copy(src);
15451c0b2f7Stbbdev     src_copy.activate();
15551c0b2f7Stbbdev     test_push_receiver<T> dest_c(g);
15651c0b2f7Stbbdev     CHECK_MESSAGE( src_copy.register_successor(dest_c), "" );
15751c0b2f7Stbbdev     g.wait_for_all();
15851c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
15951c0b2f7Stbbdev         CHECK_MESSAGE( dest_c.get_count(i) == 1, "" );
16051c0b2f7Stbbdev     }
16151c0b2f7Stbbdev }
16251c0b2f7Stbbdev 
16351c0b2f7Stbbdev void test_reset() {
16451c0b2f7Stbbdev     //    input_node -> function_node
16551c0b2f7Stbbdev     tbb::flow::graph g;
16651c0b2f7Stbbdev     std::atomic<int> counters3[N];
16751c0b2f7Stbbdev     tbb::flow::input_node<int> src3(g, my_input_body<int>());
16851c0b2f7Stbbdev     src3.activate();
16951c0b2f7Stbbdev     tbb::flow::input_node<int> src_inactive(g, my_input_body<int>());
17051c0b2f7Stbbdev     function_body<int> b3( counters3 );
17151c0b2f7Stbbdev     tbb::flow::function_node<int,bool> dest3(g, tbb::flow::unlimited, b3);
17251c0b2f7Stbbdev     tbb::flow::make_edge( src3, dest3 );
17351c0b2f7Stbbdev     //    source_node already in active state.  Let the graph run,
17451c0b2f7Stbbdev     g.wait_for_all();
17551c0b2f7Stbbdev     //    check the array for each value.
17651c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
17751c0b2f7Stbbdev         int v = counters3[i];
17851c0b2f7Stbbdev         CHECK_MESSAGE( v == 1, "" );
17951c0b2f7Stbbdev         counters3[i] = 0;
18051c0b2f7Stbbdev     }
18151c0b2f7Stbbdev 
18251c0b2f7Stbbdev     g.reset(tbb::flow::rf_reset_bodies);  // <-- re-initializes the counts.
18351c0b2f7Stbbdev     // and spawns task to run input
18451c0b2f7Stbbdev     src3.activate();
18551c0b2f7Stbbdev 
18651c0b2f7Stbbdev     g.wait_for_all();
18751c0b2f7Stbbdev     //    check output queue again.  Should be the same contents.
18851c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
18951c0b2f7Stbbdev         int v = counters3[i];
19051c0b2f7Stbbdev         CHECK_MESSAGE( v == 1, "" );
19151c0b2f7Stbbdev         counters3[i] = 0;
19251c0b2f7Stbbdev     }
19351c0b2f7Stbbdev     g.reset();  // doesn't reset the input_node_body to initial state, but does spawn a task
19451c0b2f7Stbbdev                 // to run the input_node.
19551c0b2f7Stbbdev 
19651c0b2f7Stbbdev     g.wait_for_all();
19751c0b2f7Stbbdev     // array should be all zero
19851c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
19951c0b2f7Stbbdev         int v = counters3[i];
20051c0b2f7Stbbdev         CHECK_MESSAGE( v == 0, "" );
20151c0b2f7Stbbdev     }
20251c0b2f7Stbbdev 
20351c0b2f7Stbbdev     remove_edge(src3, dest3);
20451c0b2f7Stbbdev     make_edge(src_inactive, dest3);
20551c0b2f7Stbbdev 
20651c0b2f7Stbbdev     // src_inactive doesn't run
20751c0b2f7Stbbdev     g.wait_for_all();
20851c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
20951c0b2f7Stbbdev         int v = counters3[i];
21051c0b2f7Stbbdev         CHECK_MESSAGE( v == 0, "" );
21151c0b2f7Stbbdev     }
21251c0b2f7Stbbdev 
21351c0b2f7Stbbdev     // run graph
21451c0b2f7Stbbdev     src_inactive.activate();
21551c0b2f7Stbbdev     g.wait_for_all();
21651c0b2f7Stbbdev     // check output
21751c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
21851c0b2f7Stbbdev         int v = counters3[i];
21951c0b2f7Stbbdev         CHECK_MESSAGE( v == 1, "" );
22051c0b2f7Stbbdev         counters3[i] = 0;
22151c0b2f7Stbbdev     }
22251c0b2f7Stbbdev     g.reset(tbb::flow::rf_reset_bodies);  // <-- reinitializes the counts
22351c0b2f7Stbbdev     // src_inactive doesn't run
22451c0b2f7Stbbdev     g.wait_for_all();
22551c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
22651c0b2f7Stbbdev         int v = counters3[i];
22751c0b2f7Stbbdev         CHECK_MESSAGE( v == 0, "" );
22851c0b2f7Stbbdev     }
22951c0b2f7Stbbdev 
23051c0b2f7Stbbdev     // start it up
23151c0b2f7Stbbdev     src_inactive.activate();
23251c0b2f7Stbbdev     g.wait_for_all();
23351c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
23451c0b2f7Stbbdev         int v = counters3[i];
23551c0b2f7Stbbdev         CHECK_MESSAGE( v == 1, "" );
23651c0b2f7Stbbdev         counters3[i] = 0;
23751c0b2f7Stbbdev     }
23851c0b2f7Stbbdev     g.reset();  // doesn't reset the input_node_body to initial state, and doesn't
23951c0b2f7Stbbdev                 // spawn a task to run the input_node.
24051c0b2f7Stbbdev 
24151c0b2f7Stbbdev     g.wait_for_all();
24251c0b2f7Stbbdev     // array should be all zero
24351c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
24451c0b2f7Stbbdev         int v = counters3[i];
24551c0b2f7Stbbdev         CHECK_MESSAGE( v == 0, "" );
24651c0b2f7Stbbdev     }
24751c0b2f7Stbbdev     src_inactive.activate();
24851c0b2f7Stbbdev     // input_node_body is already in final state, so input_node will not forward a message.
24951c0b2f7Stbbdev     g.wait_for_all();
25051c0b2f7Stbbdev     for (int i = 0; i < N; ++i ) {
25151c0b2f7Stbbdev         int v = counters3[i];
25251c0b2f7Stbbdev         CHECK_MESSAGE( v == 0, "" );
25351c0b2f7Stbbdev     }
25451c0b2f7Stbbdev }
25551c0b2f7Stbbdev 
25651c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
25751c0b2f7Stbbdev int input_body_f(tbb::flow_control&) { return 42; }
25851c0b2f7Stbbdev 
25951c0b2f7Stbbdev void test_deduction_guides() {
26051c0b2f7Stbbdev     using namespace tbb::flow;
26151c0b2f7Stbbdev     graph g;
26251c0b2f7Stbbdev 
26351c0b2f7Stbbdev     auto lambda = [](tbb::flow_control&) { return 42; };
26451c0b2f7Stbbdev     auto non_const_lambda = [](tbb::flow_control&) mutable { return 42; };
26551c0b2f7Stbbdev 
26651c0b2f7Stbbdev     // Tests for input_node(graph&, Body)
26751c0b2f7Stbbdev     input_node s1(g, lambda);
26851c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(s1), input_node<int>>);
26951c0b2f7Stbbdev 
27051c0b2f7Stbbdev     input_node s2(g, non_const_lambda);
27151c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(s2), input_node<int>>);
27251c0b2f7Stbbdev 
27351c0b2f7Stbbdev     input_node s3(g, input_body_f);
27451c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(s3), input_node<int>>);
27551c0b2f7Stbbdev 
27651c0b2f7Stbbdev     input_node s4(s3);
27751c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(s4), input_node<int>>);
27851c0b2f7Stbbdev 
27951c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
28051c0b2f7Stbbdev     broadcast_node<int> bc(g);
28151c0b2f7Stbbdev 
28251c0b2f7Stbbdev     // Tests for input_node(const node_set<Args...>&, Body)
28351c0b2f7Stbbdev     input_node s5(precedes(bc), lambda);
28451c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(s5), input_node<int>>);
28551c0b2f7Stbbdev 
28651c0b2f7Stbbdev     input_node s6(precedes(bc), non_const_lambda);
28751c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(s6), input_node<int>>);
28851c0b2f7Stbbdev 
28951c0b2f7Stbbdev     input_node s7(precedes(bc), input_body_f);
29051c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(s7), input_node<int>>);
29151c0b2f7Stbbdev #endif
29251c0b2f7Stbbdev     g.wait_for_all();
29351c0b2f7Stbbdev }
29451c0b2f7Stbbdev 
29551c0b2f7Stbbdev #endif // __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
29651c0b2f7Stbbdev 
29751c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
29851c0b2f7Stbbdev #include <array>
29951c0b2f7Stbbdev void test_follows_and_precedes_api() {
30051c0b2f7Stbbdev     using namespace tbb::flow;
30151c0b2f7Stbbdev 
30251c0b2f7Stbbdev     graph g;
30351c0b2f7Stbbdev 
30451c0b2f7Stbbdev     std::array<buffer_node<bool>, 3> successors {{
30551c0b2f7Stbbdev                                                   buffer_node<bool>(g),
30651c0b2f7Stbbdev                                                   buffer_node<bool>(g),
30751c0b2f7Stbbdev                                                   buffer_node<bool>(g)
30851c0b2f7Stbbdev         }};
30951c0b2f7Stbbdev 
31051c0b2f7Stbbdev     bool do_try_put = true;
31151c0b2f7Stbbdev     input_node<bool> src(
31251c0b2f7Stbbdev         precedes(successors[0], successors[1], successors[2]),
31351c0b2f7Stbbdev         [&](tbb::flow_control& fc) -> bool {
31451c0b2f7Stbbdev             if(!do_try_put)
31551c0b2f7Stbbdev                 fc.stop();
31651c0b2f7Stbbdev             do_try_put = !do_try_put;
31751c0b2f7Stbbdev             return true;
31851c0b2f7Stbbdev         }
31951c0b2f7Stbbdev     );
32051c0b2f7Stbbdev 
32151c0b2f7Stbbdev     src.activate();
32251c0b2f7Stbbdev     g.wait_for_all();
32351c0b2f7Stbbdev 
32451c0b2f7Stbbdev     bool storage;
32551c0b2f7Stbbdev     for(auto& successor: successors) {
32651c0b2f7Stbbdev         CHECK_MESSAGE((successor.try_get(storage) && !successor.try_get(storage)),
32751c0b2f7Stbbdev                       "Not exact edge quantity was made");
32851c0b2f7Stbbdev     }
32951c0b2f7Stbbdev }
33051c0b2f7Stbbdev #endif // __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
33151c0b2f7Stbbdev 
33251c0b2f7Stbbdev //! Test push, push-pull behavior and copy constructor
33351c0b2f7Stbbdev //! \brief \ref error_guessing \ref requirement
33451c0b2f7Stbbdev TEST_CASE("Single destination tests"){
33551c0b2f7Stbbdev     for ( unsigned int p = utils::MinThread; p < utils::MaxThread; ++p ) {
33651c0b2f7Stbbdev         tbb::task_arena arena(p);
33751c0b2f7Stbbdev         arena.execute(
33851c0b2f7Stbbdev             [&]() {
33951c0b2f7Stbbdev                 test_single_dest<int>();
34051c0b2f7Stbbdev                 test_single_dest<float>();
34151c0b2f7Stbbdev             }
34251c0b2f7Stbbdev         );
34351c0b2f7Stbbdev 	}
34451c0b2f7Stbbdev }
34551c0b2f7Stbbdev 
34651c0b2f7Stbbdev //! Test reset variants
34751c0b2f7Stbbdev //! \brief \ref error_guessing
34851c0b2f7Stbbdev TEST_CASE("Reset test"){
34951c0b2f7Stbbdev     test_reset();
35051c0b2f7Stbbdev }
35151c0b2f7Stbbdev 
35251c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
35351c0b2f7Stbbdev //! Test follows and precedes API
35451c0b2f7Stbbdev //! \brief \ref error_guessing
35551c0b2f7Stbbdev TEST_CASE("Follows and precedes API"){
35651c0b2f7Stbbdev     test_follows_and_precedes_api();
35751c0b2f7Stbbdev }
35851c0b2f7Stbbdev #endif
35951c0b2f7Stbbdev 
36051c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
36151c0b2f7Stbbdev //! Test deduction guides
36251c0b2f7Stbbdev //! \brief \ref requirement
36351c0b2f7Stbbdev TEST_CASE("Deduction guides"){
36451c0b2f7Stbbdev     test_deduction_guides();
36551c0b2f7Stbbdev }
36651c0b2f7Stbbdev #endif
36751c0b2f7Stbbdev 
36851c0b2f7Stbbdev //! Test try_get before activation
36951c0b2f7Stbbdev //! \brief \ref error_guessing
37051c0b2f7Stbbdev TEST_CASE("try_get before activation"){
37151c0b2f7Stbbdev     tbb::flow::graph g;
37251c0b2f7Stbbdev     tbb::flow::input_node<int> in(g, [&](tbb::flow_control& fc) -> bool { fc.stop(); return 0;});
37351c0b2f7Stbbdev 
37451c0b2f7Stbbdev     int tmp = -1;
37551c0b2f7Stbbdev     CHECK_MESSAGE((in.try_get(tmp) == false), "try_get before activation should not succeed");
37651c0b2f7Stbbdev }
377