xref: /oneTBB/test/tbb/test_broadcast_node.cpp (revision c4a799df)
1 /*
2     Copyright (c) 2005-2023 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #include "common/config.h"
18 
19 #include "tbb/flow_graph.h"
20 
21 #include "common/test.h"
22 #include "common/utils.h"
23 #include "common/test_follows_and_precedes_api.h"
24 
25 #include <atomic>
26 
27 
28 //! \file test_broadcast_node.cpp
29 //! \brief Test for [flow_graph.broadcast_node] specification
30 
31 
32 #define TBB_INTERNAL_NAMESPACE detail::d1
33 namespace tbb {
34 using task = TBB_INTERNAL_NAMESPACE::graph_task;
35 }
36 using tbb::TBB_INTERNAL_NAMESPACE::SUCCESSFULLY_ENQUEUED;
37 
// Number of distinct message values (0 .. N-1) each sender puts; also the size
// of every counting_array_receiver's counter array.
constexpr int N = 1000;

// Exclusive upper bound on the number of receivers attached in one round
// (the tests run rounds with 1 .. R-1 receivers).
constexpr int R = 4;
40 
41 class int_convertable_type : private utils::NoAssign {
42 
43    int my_value;
44 
45 public:
46 
int_convertable_type(int v)47    int_convertable_type( int v ) : my_value(v) {}
operator int() const48    operator int() const { return my_value; }
49 
50 };
51 
52 
53 template< typename T >
54 class counting_array_receiver : public tbb::flow::receiver<T> {
55 
56     std::atomic<size_t> my_counters[N];
57     tbb::flow::graph& my_graph;
58 
59 public:
60 
counting_array_receiver(tbb::flow::graph & g)61     counting_array_receiver(tbb::flow::graph& g) : my_graph(g) {
62         for (int i = 0; i < N; ++i )
63            my_counters[i] = 0;
64     }
65 
operator [](int i)66     size_t operator[]( int i ) {
67         size_t v = my_counters[i];
68         return v;
69     }
70 
try_put_task(const T & v)71     tbb::task * try_put_task( const T &v ) override {
72         ++my_counters[(int)v];
73         return const_cast<tbb::task *>(SUCCESSFULLY_ENQUEUED);
74     }
75 
graph_reference() const76     tbb::flow::graph& graph_reference() const override {
77         return my_graph;
78     }
79 };
80 
81 template< typename T >
test_serial_broadcasts()82 void test_serial_broadcasts() {
83 
84     tbb::flow::graph g;
85     tbb::flow::broadcast_node<T> b(g);
86 
87     for ( int num_receivers = 1; num_receivers < R; ++num_receivers ) {
88         std::vector< std::shared_ptr<counting_array_receiver<T>> > receivers;
89         for( int i = 0; i < num_receivers; ++i )
90             receivers.push_back( std::make_shared<counting_array_receiver<T>>(g) );
91 
92         for ( int r = 0; r < num_receivers; ++r ) {
93             tbb::flow::make_edge( b, *receivers[r] );
94         }
95 
96         for (int n = 0; n < N; ++n ) {
97             CHECK_MESSAGE( b.try_put( (T)n ), "" );
98         }
99 
100         for ( int r = 0; r < num_receivers; ++r ) {
101             for (int n = 0; n < N; ++n ) {
102                 CHECK_MESSAGE( (*receivers[r])[n] == 1, "" );
103             }
104             tbb::flow::remove_edge( b, *receivers[r] );
105         }
106         CHECK_MESSAGE( b.try_put( (T)0 ), "" );
107         for ( int r = 0; r < num_receivers; ++r )
108             CHECK_MESSAGE( (*receivers[0])[0] == 1, "" );
109     }
110 
111 }
112 
113 template< typename T >
114 class native_body : private utils::NoAssign {
115 
116     tbb::flow::broadcast_node<T> &my_b;
117 
118 public:
119 
native_body(tbb::flow::broadcast_node<T> & b)120     native_body( tbb::flow::broadcast_node<T> &b ) : my_b(b) {}
121 
operator ()(int) const122     void operator()(int) const {
123         for (int n = 0; n < N; ++n ) {
124             CHECK_MESSAGE( my_b.try_put( (T)n ), "" );
125         }
126     }
127 
128 };
129 
130 template< typename T >
run_parallel_broadcasts(tbb::flow::graph & g,int p,tbb::flow::broadcast_node<T> & b)131 void run_parallel_broadcasts(tbb::flow::graph& g, int p, tbb::flow::broadcast_node<T>& b) {
132     for ( int num_receivers = 1; num_receivers < R; ++num_receivers ) {
133         std::vector< std::shared_ptr<counting_array_receiver<T>> > receivers;
134         for( int i = 0; i < num_receivers; ++i )
135             receivers.push_back( std::make_shared< counting_array_receiver<T> >(g) );
136 
137         for ( int r = 0; r < num_receivers; ++r ) {
138             tbb::flow::make_edge( b, *receivers[r] );
139         }
140 
141         utils::NativeParallelFor( p, native_body<T>( b ) );
142 
143         for ( int r = 0; r < num_receivers; ++r ) {
144             for (int n = 0; n < N; ++n ) {
145                 CHECK_MESSAGE( (int)(*receivers[r])[n] == p, "" );
146             }
147             tbb::flow::remove_edge( b, *receivers[r] );
148         }
149         CHECK_MESSAGE( b.try_put( (T)0 ), "" );
150         for ( int r = 0; r < num_receivers; ++r )
151             CHECK_MESSAGE( (int)(*receivers[r])[0] == p, "" );
152     }
153 }
154 
155 template< typename T >
test_parallel_broadcasts(int p)156 void test_parallel_broadcasts(int p) {
157 
158     tbb::flow::graph g;
159     tbb::flow::broadcast_node<T> b(g);
160     run_parallel_broadcasts(g, p, b);
161 
162     // test copy constructor
163     tbb::flow::broadcast_node<T> b_copy(b);
164     run_parallel_broadcasts(g, p, b_copy);
165 }
166 
167 // broadcast_node does not allow successors to try_get from it (it does not allow
168 // the flow edge to switch) so we only need test the forward direction.
169 template<typename T>
test_resets()170 void test_resets() {
171     tbb::flow::graph g;
172     tbb::flow::broadcast_node<T> b0(g);
173     tbb::flow::broadcast_node<T> b1(g);
174     tbb::flow::queue_node<T> q0(g);
175     tbb::flow::make_edge(b0,b1);
176     tbb::flow::make_edge(b1,q0);
177     T j;
178 
179     // test standard reset
180     for(int testNo = 0; testNo < 2; ++testNo) {
181         for(T i= 0; i <= 3; i += 1) {
182             b0.try_put(i);
183         }
184         g.wait_for_all();
185         for(T i= 0; i <= 3; i += 1) {
186             CHECK_MESSAGE( (q0.try_get(j) && j == i), "Bad value in queue");
187         }
188         CHECK_MESSAGE( (!q0.try_get(j)), "extra value in queue");
189 
190         // reset the graph.  It should work as before.
191         if (testNo == 0) g.reset();
192     }
193 
194     g.reset(tbb::flow::rf_clear_edges);
195     for(T i= 0; i <= 3; i += 1) {
196         b0.try_put(i);
197     }
198     g.wait_for_all();
199     CHECK_MESSAGE( (!q0.try_get(j)), "edge between nodes not removed");
200     for(T i= 0; i <= 3; i += 1) {
201         b1.try_put(i);
202     }
203     g.wait_for_all();
204     CHECK_MESSAGE( (!q0.try_get(j)), "edge between nodes not removed");
205 }
206 
207 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
208 #include <array>
209 #include <vector>
test_follows_and_precedes_api()210 void test_follows_and_precedes_api() {
211     using msg_t = tbb::flow::continue_msg;
212 
213     std::array<msg_t, 3> messages_for_follows= { {msg_t(), msg_t(), msg_t()} };
214     std::vector<msg_t> messages_for_precedes = {msg_t()};
215 
216     follows_and_precedes_testing::test_follows <msg_t, tbb::flow::broadcast_node<msg_t>>(messages_for_follows);
217     follows_and_precedes_testing::test_precedes <msg_t, tbb::flow::broadcast_node<msg_t>>(messages_for_precedes);
218 }
219 #endif
220 
221 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
test_deduction_guides()222 void test_deduction_guides() {
223     using namespace tbb::flow;
224 
225     graph g;
226 
227     broadcast_node<int> b0(g);
228 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
229     buffer_node<int> buf(g);
230 
231     broadcast_node b1(follows(buf));
232     static_assert(std::is_same_v<decltype(b1), broadcast_node<int>>);
233 
234     broadcast_node b2(precedes(buf));
235     static_assert(std::is_same_v<decltype(b2), broadcast_node<int>>);
236 #endif
237 
238     broadcast_node b3(b0);
239     static_assert(std::is_same_v<decltype(b3), broadcast_node<int>>);
240     g.wait_for_all();
241 }
242 #endif
243 
244 //! Test serial broadcasts
245 //! \brief \ref error_guessing
246 TEST_CASE("Serial broadcasts"){
247    test_serial_broadcasts<int>();
248    test_serial_broadcasts<float>();
249    test_serial_broadcasts<int_convertable_type>();
250 }
251 
252 //! Test parallel broadcasts
253 //! \brief \ref error_guessing
254 TEST_CASE("Parallel broadcasts"){
255     for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
256        test_parallel_broadcasts<int>(p);
257        test_parallel_broadcasts<float>(p);
258        test_parallel_broadcasts<int_convertable_type>(p);
259    }
260 }
261 
262 //! Test reset and cancellation behavior
263 //! \brief \ref error_guessing
264 TEST_CASE("Resets"){
265    test_resets<int>();
266    test_resets<float>();
267 }
268 
269 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
270 //! Test deprecated follows and precedes API
271 //! \brief \ref error_guessing
272 TEST_CASE("Follows and precedes API"){
273     test_follows_and_precedes_api();
274 }
275 #endif
276 
277 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
278 //! Test deduction guides
279 //! \brief requirement
280 TEST_CASE("Deduction guides"){
281     test_deduction_guides();
282 }
283 #endif
284 
285