1 /*
2 Copyright (c) 2005-2023 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
#include "common/config.h"

#include "tbb/flow_graph.h"

#include "common/test.h"
#include "common/utils.h"
#include "common/test_follows_and_precedes_api.h"

#include <atomic>
#include <memory>
#include <vector>
26
27
28 //! \file test_broadcast_node.cpp
29 //! \brief Test for [flow_graph.broadcast_node] specification
30
31
32 #define TBB_INTERNAL_NAMESPACE detail::d1
33 namespace tbb {
34 using task = TBB_INTERNAL_NAMESPACE::graph_task;
35 }
36 using tbb::TBB_INTERNAL_NAMESPACE::SUCCESSFULLY_ENQUEUED;
37
// Number of distinct message values (0 .. N-1) put into each broadcast_node.
constexpr int N = 1000;
// Exclusive upper bound on the receiver count per round (rounds use 1 .. R-1 receivers).
constexpr int R = 4;
40
41 class int_convertable_type : private utils::NoAssign {
42
43 int my_value;
44
45 public:
46
int_convertable_type(int v)47 int_convertable_type( int v ) : my_value(v) {}
operator int() const48 operator int() const { return my_value; }
49
50 };
51
52
53 template< typename T >
54 class counting_array_receiver : public tbb::flow::receiver<T> {
55
56 std::atomic<size_t> my_counters[N];
57 tbb::flow::graph& my_graph;
58
59 public:
60
counting_array_receiver(tbb::flow::graph & g)61 counting_array_receiver(tbb::flow::graph& g) : my_graph(g) {
62 for (int i = 0; i < N; ++i )
63 my_counters[i] = 0;
64 }
65
operator [](int i)66 size_t operator[]( int i ) {
67 size_t v = my_counters[i];
68 return v;
69 }
70
try_put_task(const T & v)71 tbb::task * try_put_task( const T &v ) override {
72 ++my_counters[(int)v];
73 return const_cast<tbb::task *>(SUCCESSFULLY_ENQUEUED);
74 }
75
graph_reference() const76 tbb::flow::graph& graph_reference() const override {
77 return my_graph;
78 }
79 };
80
81 template< typename T >
test_serial_broadcasts()82 void test_serial_broadcasts() {
83
84 tbb::flow::graph g;
85 tbb::flow::broadcast_node<T> b(g);
86
87 for ( int num_receivers = 1; num_receivers < R; ++num_receivers ) {
88 std::vector< std::shared_ptr<counting_array_receiver<T>> > receivers;
89 for( int i = 0; i < num_receivers; ++i )
90 receivers.push_back( std::make_shared<counting_array_receiver<T>>(g) );
91
92 for ( int r = 0; r < num_receivers; ++r ) {
93 tbb::flow::make_edge( b, *receivers[r] );
94 }
95
96 for (int n = 0; n < N; ++n ) {
97 CHECK_MESSAGE( b.try_put( (T)n ), "" );
98 }
99
100 for ( int r = 0; r < num_receivers; ++r ) {
101 for (int n = 0; n < N; ++n ) {
102 CHECK_MESSAGE( (*receivers[r])[n] == 1, "" );
103 }
104 tbb::flow::remove_edge( b, *receivers[r] );
105 }
106 CHECK_MESSAGE( b.try_put( (T)0 ), "" );
107 for ( int r = 0; r < num_receivers; ++r )
108 CHECK_MESSAGE( (*receivers[0])[0] == 1, "" );
109 }
110
111 }
112
113 template< typename T >
114 class native_body : private utils::NoAssign {
115
116 tbb::flow::broadcast_node<T> &my_b;
117
118 public:
119
native_body(tbb::flow::broadcast_node<T> & b)120 native_body( tbb::flow::broadcast_node<T> &b ) : my_b(b) {}
121
operator ()(int) const122 void operator()(int) const {
123 for (int n = 0; n < N; ++n ) {
124 CHECK_MESSAGE( my_b.try_put( (T)n ), "" );
125 }
126 }
127
128 };
129
130 template< typename T >
run_parallel_broadcasts(tbb::flow::graph & g,int p,tbb::flow::broadcast_node<T> & b)131 void run_parallel_broadcasts(tbb::flow::graph& g, int p, tbb::flow::broadcast_node<T>& b) {
132 for ( int num_receivers = 1; num_receivers < R; ++num_receivers ) {
133 std::vector< std::shared_ptr<counting_array_receiver<T>> > receivers;
134 for( int i = 0; i < num_receivers; ++i )
135 receivers.push_back( std::make_shared< counting_array_receiver<T> >(g) );
136
137 for ( int r = 0; r < num_receivers; ++r ) {
138 tbb::flow::make_edge( b, *receivers[r] );
139 }
140
141 utils::NativeParallelFor( p, native_body<T>( b ) );
142
143 for ( int r = 0; r < num_receivers; ++r ) {
144 for (int n = 0; n < N; ++n ) {
145 CHECK_MESSAGE( (int)(*receivers[r])[n] == p, "" );
146 }
147 tbb::flow::remove_edge( b, *receivers[r] );
148 }
149 CHECK_MESSAGE( b.try_put( (T)0 ), "" );
150 for ( int r = 0; r < num_receivers; ++r )
151 CHECK_MESSAGE( (int)(*receivers[r])[0] == p, "" );
152 }
153 }
154
155 template< typename T >
test_parallel_broadcasts(int p)156 void test_parallel_broadcasts(int p) {
157
158 tbb::flow::graph g;
159 tbb::flow::broadcast_node<T> b(g);
160 run_parallel_broadcasts(g, p, b);
161
162 // test copy constructor
163 tbb::flow::broadcast_node<T> b_copy(b);
164 run_parallel_broadcasts(g, p, b_copy);
165 }
166
// broadcast_node does not allow successors to try_get from it (it does not allow
// the flow edge to switch), so we only need to test the forward direction.
169 template<typename T>
test_resets()170 void test_resets() {
171 tbb::flow::graph g;
172 tbb::flow::broadcast_node<T> b0(g);
173 tbb::flow::broadcast_node<T> b1(g);
174 tbb::flow::queue_node<T> q0(g);
175 tbb::flow::make_edge(b0,b1);
176 tbb::flow::make_edge(b1,q0);
177 T j;
178
179 // test standard reset
180 for(int testNo = 0; testNo < 2; ++testNo) {
181 for(T i= 0; i <= 3; i += 1) {
182 b0.try_put(i);
183 }
184 g.wait_for_all();
185 for(T i= 0; i <= 3; i += 1) {
186 CHECK_MESSAGE( (q0.try_get(j) && j == i), "Bad value in queue");
187 }
188 CHECK_MESSAGE( (!q0.try_get(j)), "extra value in queue");
189
190 // reset the graph. It should work as before.
191 if (testNo == 0) g.reset();
192 }
193
194 g.reset(tbb::flow::rf_clear_edges);
195 for(T i= 0; i <= 3; i += 1) {
196 b0.try_put(i);
197 }
198 g.wait_for_all();
199 CHECK_MESSAGE( (!q0.try_get(j)), "edge between nodes not removed");
200 for(T i= 0; i <= 3; i += 1) {
201 b1.try_put(i);
202 }
203 g.wait_for_all();
204 CHECK_MESSAGE( (!q0.try_get(j)), "edge between nodes not removed");
205 }
206
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
#include <array>
#include <vector>
//! Exercises the preview follows()/precedes() node-set construction API for
//! broadcast_node, using continue_msg as the message type.
void test_follows_and_precedes_api() {
    using msg_t = tbb::flow::continue_msg;
    using node_t = tbb::flow::broadcast_node<msg_t>;

    std::array<msg_t, 3> follows_msgs = { {msg_t(), msg_t(), msg_t()} };
    std::vector<msg_t> precedes_msgs = {msg_t()};

    follows_and_precedes_testing::test_follows<msg_t, node_t>(follows_msgs);
    follows_and_precedes_testing::test_precedes<msg_t, node_t>(precedes_msgs);
}
#endif
220
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Compile-time checks that class template argument deduction yields
//! broadcast_node<int> for each supported constructor form.
void test_deduction_guides() {
    using namespace tbb::flow;

    graph g;
    broadcast_node<int> int_node(g);

#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
    buffer_node<int> buf(g);

    // Deduced from follows(buf).
    broadcast_node from_follows(follows(buf));
    static_assert(std::is_same_v<decltype(from_follows), broadcast_node<int>>);

    // Deduced from precedes(buf).
    broadcast_node from_precedes(precedes(buf));
    static_assert(std::is_same_v<decltype(from_precedes), broadcast_node<int>>);
#endif

    // Deduced from copy construction.
    broadcast_node copied(int_node);
    static_assert(std::is_same_v<decltype(copied), broadcast_node<int>>);

    g.wait_for_all();
}
#endif
243
244 //! Test serial broadcasts
245 //! \brief \ref error_guessing
246 TEST_CASE("Serial broadcasts"){
247 test_serial_broadcasts<int>();
248 test_serial_broadcasts<float>();
249 test_serial_broadcasts<int_convertable_type>();
250 }
251
252 //! Test parallel broadcasts
253 //! \brief \ref error_guessing
254 TEST_CASE("Parallel broadcasts"){
255 for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
256 test_parallel_broadcasts<int>(p);
257 test_parallel_broadcasts<float>(p);
258 test_parallel_broadcasts<int_convertable_type>(p);
259 }
260 }
261
262 //! Test reset and cancellation behavior
263 //! \brief \ref error_guessing
264 TEST_CASE("Resets"){
265 test_resets<int>();
266 test_resets<float>();
267 }
268
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
//! Test deprecated follows and precedes API
//! \brief \ref error_guessing
TEST_CASE("Follows and precedes API") {
    // Only compiled when the preview node-set API is enabled.
    test_follows_and_precedes_api();
}
#endif
276
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Test deduction guides
//! \brief requirement
TEST_CASE("Deduction guides") {
    // The type checks inside are static_asserts; the call builds the deduced nodes.
    test_deduction_guides();
}
#endif
284
285