xref: /oneTBB/test/tbb/test_broadcast_node.cpp (revision c4a799df)
151c0b2f7Stbbdev /*
2*c4a799dfSJhaShweta1     Copyright (c) 2005-2023 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
1751c0b2f7Stbbdev #include "common/config.h"
1851c0b2f7Stbbdev 
1951c0b2f7Stbbdev #include "tbb/flow_graph.h"
2051c0b2f7Stbbdev 
2151c0b2f7Stbbdev #include "common/test.h"
2251c0b2f7Stbbdev #include "common/utils.h"
2351c0b2f7Stbbdev #include "common/test_follows_and_precedes_api.h"
2451c0b2f7Stbbdev 
2551c0b2f7Stbbdev #include <atomic>
2651c0b2f7Stbbdev 
2751c0b2f7Stbbdev 
2851c0b2f7Stbbdev //! \file test_broadcast_node.cpp
2951c0b2f7Stbbdev //! \brief Test for [flow_graph.broadcast_node] specification
3051c0b2f7Stbbdev 
3151c0b2f7Stbbdev 
3251c0b2f7Stbbdev #define TBB_INTERNAL_NAMESPACE detail::d1
3351c0b2f7Stbbdev namespace tbb {
3451c0b2f7Stbbdev using task = TBB_INTERNAL_NAMESPACE::graph_task;
3551c0b2f7Stbbdev }
3651c0b2f7Stbbdev using tbb::TBB_INTERNAL_NAMESPACE::SUCCESSFULLY_ENQUEUED;
3751c0b2f7Stbbdev 
// Number of distinct message values sent per broadcast round; it is also the
// size of every receiver's counter table.
constexpr int N = 1000;
// Exclusive upper bound on the number of receivers attached to a node.
constexpr int R = 4;
4051c0b2f7Stbbdev 
4151c0b2f7Stbbdev class int_convertable_type : private utils::NoAssign {
4251c0b2f7Stbbdev 
4351c0b2f7Stbbdev    int my_value;
4451c0b2f7Stbbdev 
4551c0b2f7Stbbdev public:
4651c0b2f7Stbbdev 
int_convertable_type(int v)4751c0b2f7Stbbdev    int_convertable_type( int v ) : my_value(v) {}
operator int() const4851c0b2f7Stbbdev    operator int() const { return my_value; }
4951c0b2f7Stbbdev 
5051c0b2f7Stbbdev };
5151c0b2f7Stbbdev 
5251c0b2f7Stbbdev 
5351c0b2f7Stbbdev template< typename T >
5451c0b2f7Stbbdev class counting_array_receiver : public tbb::flow::receiver<T> {
5551c0b2f7Stbbdev 
5651c0b2f7Stbbdev     std::atomic<size_t> my_counters[N];
5751c0b2f7Stbbdev     tbb::flow::graph& my_graph;
5851c0b2f7Stbbdev 
5951c0b2f7Stbbdev public:
6051c0b2f7Stbbdev 
counting_array_receiver(tbb::flow::graph & g)6151c0b2f7Stbbdev     counting_array_receiver(tbb::flow::graph& g) : my_graph(g) {
6251c0b2f7Stbbdev         for (int i = 0; i < N; ++i )
6351c0b2f7Stbbdev            my_counters[i] = 0;
6451c0b2f7Stbbdev     }
6551c0b2f7Stbbdev 
operator [](int i)6651c0b2f7Stbbdev     size_t operator[]( int i ) {
6751c0b2f7Stbbdev         size_t v = my_counters[i];
6851c0b2f7Stbbdev         return v;
6951c0b2f7Stbbdev     }
7051c0b2f7Stbbdev 
try_put_task(const T & v)7151c0b2f7Stbbdev     tbb::task * try_put_task( const T &v ) override {
7251c0b2f7Stbbdev         ++my_counters[(int)v];
7351c0b2f7Stbbdev         return const_cast<tbb::task *>(SUCCESSFULLY_ENQUEUED);
7451c0b2f7Stbbdev     }
7551c0b2f7Stbbdev 
graph_reference() const7651c0b2f7Stbbdev     tbb::flow::graph& graph_reference() const override {
7751c0b2f7Stbbdev         return my_graph;
7851c0b2f7Stbbdev     }
7951c0b2f7Stbbdev };
8051c0b2f7Stbbdev 
8151c0b2f7Stbbdev template< typename T >
test_serial_broadcasts()8251c0b2f7Stbbdev void test_serial_broadcasts() {
8351c0b2f7Stbbdev 
8451c0b2f7Stbbdev     tbb::flow::graph g;
8551c0b2f7Stbbdev     tbb::flow::broadcast_node<T> b(g);
8651c0b2f7Stbbdev 
8751c0b2f7Stbbdev     for ( int num_receivers = 1; num_receivers < R; ++num_receivers ) {
8851c0b2f7Stbbdev         std::vector< std::shared_ptr<counting_array_receiver<T>> > receivers;
8951c0b2f7Stbbdev         for( int i = 0; i < num_receivers; ++i )
9051c0b2f7Stbbdev             receivers.push_back( std::make_shared<counting_array_receiver<T>>(g) );
9151c0b2f7Stbbdev 
9251c0b2f7Stbbdev         for ( int r = 0; r < num_receivers; ++r ) {
9351c0b2f7Stbbdev             tbb::flow::make_edge( b, *receivers[r] );
9451c0b2f7Stbbdev         }
9551c0b2f7Stbbdev 
9651c0b2f7Stbbdev         for (int n = 0; n < N; ++n ) {
9751c0b2f7Stbbdev             CHECK_MESSAGE( b.try_put( (T)n ), "" );
9851c0b2f7Stbbdev         }
9951c0b2f7Stbbdev 
10051c0b2f7Stbbdev         for ( int r = 0; r < num_receivers; ++r ) {
10151c0b2f7Stbbdev             for (int n = 0; n < N; ++n ) {
10251c0b2f7Stbbdev                 CHECK_MESSAGE( (*receivers[r])[n] == 1, "" );
10351c0b2f7Stbbdev             }
10451c0b2f7Stbbdev             tbb::flow::remove_edge( b, *receivers[r] );
10551c0b2f7Stbbdev         }
10651c0b2f7Stbbdev         CHECK_MESSAGE( b.try_put( (T)0 ), "" );
10751c0b2f7Stbbdev         for ( int r = 0; r < num_receivers; ++r )
10851c0b2f7Stbbdev             CHECK_MESSAGE( (*receivers[0])[0] == 1, "" );
10951c0b2f7Stbbdev     }
11051c0b2f7Stbbdev 
11151c0b2f7Stbbdev }
11251c0b2f7Stbbdev 
11351c0b2f7Stbbdev template< typename T >
11451c0b2f7Stbbdev class native_body : private utils::NoAssign {
11551c0b2f7Stbbdev 
11651c0b2f7Stbbdev     tbb::flow::broadcast_node<T> &my_b;
11751c0b2f7Stbbdev 
11851c0b2f7Stbbdev public:
11951c0b2f7Stbbdev 
native_body(tbb::flow::broadcast_node<T> & b)12051c0b2f7Stbbdev     native_body( tbb::flow::broadcast_node<T> &b ) : my_b(b) {}
12151c0b2f7Stbbdev 
operator ()(int) const12251c0b2f7Stbbdev     void operator()(int) const {
12351c0b2f7Stbbdev         for (int n = 0; n < N; ++n ) {
12451c0b2f7Stbbdev             CHECK_MESSAGE( my_b.try_put( (T)n ), "" );
12551c0b2f7Stbbdev         }
12651c0b2f7Stbbdev     }
12751c0b2f7Stbbdev 
12851c0b2f7Stbbdev };
12951c0b2f7Stbbdev 
13051c0b2f7Stbbdev template< typename T >
run_parallel_broadcasts(tbb::flow::graph & g,int p,tbb::flow::broadcast_node<T> & b)13151c0b2f7Stbbdev void run_parallel_broadcasts(tbb::flow::graph& g, int p, tbb::flow::broadcast_node<T>& b) {
13251c0b2f7Stbbdev     for ( int num_receivers = 1; num_receivers < R; ++num_receivers ) {
13351c0b2f7Stbbdev         std::vector< std::shared_ptr<counting_array_receiver<T>> > receivers;
13451c0b2f7Stbbdev         for( int i = 0; i < num_receivers; ++i )
13551c0b2f7Stbbdev             receivers.push_back( std::make_shared< counting_array_receiver<T> >(g) );
13651c0b2f7Stbbdev 
13751c0b2f7Stbbdev         for ( int r = 0; r < num_receivers; ++r ) {
13851c0b2f7Stbbdev             tbb::flow::make_edge( b, *receivers[r] );
13951c0b2f7Stbbdev         }
14051c0b2f7Stbbdev 
14151c0b2f7Stbbdev         utils::NativeParallelFor( p, native_body<T>( b ) );
14251c0b2f7Stbbdev 
14351c0b2f7Stbbdev         for ( int r = 0; r < num_receivers; ++r ) {
14451c0b2f7Stbbdev             for (int n = 0; n < N; ++n ) {
14551c0b2f7Stbbdev                 CHECK_MESSAGE( (int)(*receivers[r])[n] == p, "" );
14651c0b2f7Stbbdev             }
14751c0b2f7Stbbdev             tbb::flow::remove_edge( b, *receivers[r] );
14851c0b2f7Stbbdev         }
14951c0b2f7Stbbdev         CHECK_MESSAGE( b.try_put( (T)0 ), "" );
15051c0b2f7Stbbdev         for ( int r = 0; r < num_receivers; ++r )
15151c0b2f7Stbbdev             CHECK_MESSAGE( (int)(*receivers[r])[0] == p, "" );
15251c0b2f7Stbbdev     }
15351c0b2f7Stbbdev }
15451c0b2f7Stbbdev 
15551c0b2f7Stbbdev template< typename T >
test_parallel_broadcasts(int p)15651c0b2f7Stbbdev void test_parallel_broadcasts(int p) {
15751c0b2f7Stbbdev 
15851c0b2f7Stbbdev     tbb::flow::graph g;
15951c0b2f7Stbbdev     tbb::flow::broadcast_node<T> b(g);
16051c0b2f7Stbbdev     run_parallel_broadcasts(g, p, b);
16151c0b2f7Stbbdev 
16251c0b2f7Stbbdev     // test copy constructor
16351c0b2f7Stbbdev     tbb::flow::broadcast_node<T> b_copy(b);
16451c0b2f7Stbbdev     run_parallel_broadcasts(g, p, b_copy);
16551c0b2f7Stbbdev }
16651c0b2f7Stbbdev 
16751c0b2f7Stbbdev // broadcast_node does not allow successors to try_get from it (it does not allow
16851c0b2f7Stbbdev // the flow edge to switch) so we only need test the forward direction.
16951c0b2f7Stbbdev template<typename T>
test_resets()17051c0b2f7Stbbdev void test_resets() {
17151c0b2f7Stbbdev     tbb::flow::graph g;
17251c0b2f7Stbbdev     tbb::flow::broadcast_node<T> b0(g);
17351c0b2f7Stbbdev     tbb::flow::broadcast_node<T> b1(g);
17451c0b2f7Stbbdev     tbb::flow::queue_node<T> q0(g);
17551c0b2f7Stbbdev     tbb::flow::make_edge(b0,b1);
17651c0b2f7Stbbdev     tbb::flow::make_edge(b1,q0);
17751c0b2f7Stbbdev     T j;
17851c0b2f7Stbbdev 
17951c0b2f7Stbbdev     // test standard reset
18051c0b2f7Stbbdev     for(int testNo = 0; testNo < 2; ++testNo) {
18151c0b2f7Stbbdev         for(T i= 0; i <= 3; i += 1) {
18251c0b2f7Stbbdev             b0.try_put(i);
18351c0b2f7Stbbdev         }
18451c0b2f7Stbbdev         g.wait_for_all();
18551c0b2f7Stbbdev         for(T i= 0; i <= 3; i += 1) {
18651c0b2f7Stbbdev             CHECK_MESSAGE( (q0.try_get(j) && j == i), "Bad value in queue");
18751c0b2f7Stbbdev         }
18851c0b2f7Stbbdev         CHECK_MESSAGE( (!q0.try_get(j)), "extra value in queue");
18951c0b2f7Stbbdev 
19051c0b2f7Stbbdev         // reset the graph.  It should work as before.
19151c0b2f7Stbbdev         if (testNo == 0) g.reset();
19251c0b2f7Stbbdev     }
19351c0b2f7Stbbdev 
19451c0b2f7Stbbdev     g.reset(tbb::flow::rf_clear_edges);
19551c0b2f7Stbbdev     for(T i= 0; i <= 3; i += 1) {
19651c0b2f7Stbbdev         b0.try_put(i);
19751c0b2f7Stbbdev     }
19851c0b2f7Stbbdev     g.wait_for_all();
19951c0b2f7Stbbdev     CHECK_MESSAGE( (!q0.try_get(j)), "edge between nodes not removed");
20051c0b2f7Stbbdev     for(T i= 0; i <= 3; i += 1) {
20151c0b2f7Stbbdev         b1.try_put(i);
20251c0b2f7Stbbdev     }
20351c0b2f7Stbbdev     g.wait_for_all();
20451c0b2f7Stbbdev     CHECK_MESSAGE( (!q0.try_get(j)), "edge between nodes not removed");
20551c0b2f7Stbbdev }
20651c0b2f7Stbbdev 
20751c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
20851c0b2f7Stbbdev #include <array>
20951c0b2f7Stbbdev #include <vector>
test_follows_and_precedes_api()21051c0b2f7Stbbdev void test_follows_and_precedes_api() {
21151c0b2f7Stbbdev     using msg_t = tbb::flow::continue_msg;
21251c0b2f7Stbbdev 
21351c0b2f7Stbbdev     std::array<msg_t, 3> messages_for_follows= { {msg_t(), msg_t(), msg_t()} };
21451c0b2f7Stbbdev     std::vector<msg_t> messages_for_precedes = {msg_t()};
21551c0b2f7Stbbdev 
21651c0b2f7Stbbdev     follows_and_precedes_testing::test_follows <msg_t, tbb::flow::broadcast_node<msg_t>>(messages_for_follows);
21751c0b2f7Stbbdev     follows_and_precedes_testing::test_precedes <msg_t, tbb::flow::broadcast_node<msg_t>>(messages_for_precedes);
21851c0b2f7Stbbdev }
21951c0b2f7Stbbdev #endif
22051c0b2f7Stbbdev 
22151c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
test_deduction_guides()22251c0b2f7Stbbdev void test_deduction_guides() {
22351c0b2f7Stbbdev     using namespace tbb::flow;
22451c0b2f7Stbbdev 
22551c0b2f7Stbbdev     graph g;
22651c0b2f7Stbbdev 
22751c0b2f7Stbbdev     broadcast_node<int> b0(g);
22851c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
22951c0b2f7Stbbdev     buffer_node<int> buf(g);
23051c0b2f7Stbbdev 
23151c0b2f7Stbbdev     broadcast_node b1(follows(buf));
23251c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(b1), broadcast_node<int>>);
23351c0b2f7Stbbdev 
23451c0b2f7Stbbdev     broadcast_node b2(precedes(buf));
23551c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(b2), broadcast_node<int>>);
23651c0b2f7Stbbdev #endif
23751c0b2f7Stbbdev 
23851c0b2f7Stbbdev     broadcast_node b3(b0);
23951c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(b3), broadcast_node<int>>);
24051c0b2f7Stbbdev     g.wait_for_all();
24151c0b2f7Stbbdev }
24251c0b2f7Stbbdev #endif
24351c0b2f7Stbbdev 
24451c0b2f7Stbbdev //! Test serial broadcasts
24551c0b2f7Stbbdev //! \brief \ref error_guessing
24651c0b2f7Stbbdev TEST_CASE("Serial broadcasts"){
24751c0b2f7Stbbdev    test_serial_broadcasts<int>();
24851c0b2f7Stbbdev    test_serial_broadcasts<float>();
24951c0b2f7Stbbdev    test_serial_broadcasts<int_convertable_type>();
25051c0b2f7Stbbdev }
25151c0b2f7Stbbdev 
25251c0b2f7Stbbdev //! Test parallel broadcasts
25351c0b2f7Stbbdev //! \brief \ref error_guessing
25451c0b2f7Stbbdev TEST_CASE("Parallel broadcasts"){
25551c0b2f7Stbbdev     for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
25651c0b2f7Stbbdev        test_parallel_broadcasts<int>(p);
25751c0b2f7Stbbdev        test_parallel_broadcasts<float>(p);
25851c0b2f7Stbbdev        test_parallel_broadcasts<int_convertable_type>(p);
25951c0b2f7Stbbdev    }
26051c0b2f7Stbbdev }
26151c0b2f7Stbbdev 
26251c0b2f7Stbbdev //! Test reset and cancellation behavior
26351c0b2f7Stbbdev //! \brief \ref error_guessing
26451c0b2f7Stbbdev TEST_CASE("Resets"){
26551c0b2f7Stbbdev    test_resets<int>();
26651c0b2f7Stbbdev    test_resets<float>();
26751c0b2f7Stbbdev }
26851c0b2f7Stbbdev 
26951c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
270*c4a799dfSJhaShweta1 //! Test deprecated follows and precedes API
27151c0b2f7Stbbdev //! \brief \ref error_guessing
272*c4a799dfSJhaShweta1 TEST_CASE("Follows and precedes API"){
27351c0b2f7Stbbdev     test_follows_and_precedes_api();
27451c0b2f7Stbbdev }
27551c0b2f7Stbbdev #endif
27651c0b2f7Stbbdev 
27751c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
27851c0b2f7Stbbdev //! Test deduction guides
27951c0b2f7Stbbdev //! \brief requirement
28051c0b2f7Stbbdev TEST_CASE("Deduction guides"){
28151c0b2f7Stbbdev     test_deduction_guides();
28251c0b2f7Stbbdev }
28351c0b2f7Stbbdev #endif
28451c0b2f7Stbbdev 
285