xref: /oneTBB/test/tbb/test_join_node.cpp (revision b15aabb3)
1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #include "common/config.h"
18 
19 #include "test_join_node.h"
20 
21 
22 //! \file test_join_node.cpp
23 //! \brief Test for [flow_graph.join_node] specification
24 
25 
// Number of join outputs seen by recirc_output_func_body; reset to 0 before each
// recirculation run and compared against Recirc_count afterwards.
static std::atomic<int> output_count{0};
27 
28 // get the tag from the output tuple and emit it.
29 // the first tuple component is tag * 2 cast to the type
30 template<typename OutputTupleType>
31 class recirc_output_func_body {
32 public:
33     // we only need this to use input_node_helper
34     typedef typename tbb::flow::join_node<OutputTupleType, tbb::flow::tag_matching> join_node_type;
35     static const int N = std::tuple_size<OutputTupleType>::value;
36     int operator()(const OutputTupleType &v) {
37         int out = int(std::get<0>(v))/2;
38         input_node_helper<N, join_node_type>::only_check_value(out, v);
39         ++output_count;
40         return out;
41     }
42 };
43 
44 template<typename JType>
45 class tag_recirculation_test {
46 public:
47     typedef typename JType::output_type TType;
48     typedef typename std::tuple<int, tbb::flow::continue_msg> input_tuple_type;
49     typedef tbb::flow::join_node<input_tuple_type, tbb::flow::reserving> input_join_type;
50     static const int N = std::tuple_size<TType>::value;
51     static void test() {
52         input_node_helper<N, JType>::print_remark("Recirculation test of tag-matching join");
53         INFO(" >\n");
54         for(int maxTag = 1; maxTag <10; maxTag *= 3) {
55             for(int i = 0; i < N; ++i) all_input_nodes[i][0] = NULL;
56 
57             tbb::flow::graph g;
58             // this is the tag-matching join we're testing
59             JType * my_join = makeJoin<N, JType, tbb::flow::tag_matching>::create(g);
60             // input_node for continue messages
61             tbb::flow::input_node<tbb::flow::continue_msg> snode(g, recirc_input_node_body());
62             // reserving join that matches recirculating tags with continue messages.
63             input_join_type * my_input_join = makeJoin<2, input_join_type, tbb::flow::reserving>::create(g);
64             // tbb::flow::make_edge(snode, tbb::flow::input_port<1>(*my_input_join));
65             tbb::flow::make_edge(snode, std::get<1>(my_input_join->input_ports()));
66             // queue to hold the tags
67             tbb::flow::queue_node<int> tag_queue(g);
68             tbb::flow::make_edge(tag_queue, tbb::flow::input_port<0>(*my_input_join));
69             // add all the function_nodes that are inputs to the tag-matching join
70             input_node_helper<N, JType>::add_recirc_func_nodes(*my_join, *my_input_join, g);
71             // add the function_node that accepts the output of the join and emits the int tag it was based on
72             tbb::flow::function_node<TType, int> recreate_tag(g, tbb::flow::unlimited, recirc_output_func_body<TType>());
73             tbb::flow::make_edge(*my_join, recreate_tag);
74             // now the recirculating part (output back to the queue)
75             tbb::flow::make_edge(recreate_tag, tag_queue);
76 
77             // put the tags into the queue
78             for(int t = 1; t<=maxTag; ++t) tag_queue.try_put(t);
79 
80             input_count = Recirc_count;
81             output_count = 0;
82 
83             // start up the source node to get things going
84             snode.activate();
85 
86             // wait for everything to stop
87             g.wait_for_all();
88 
89             CHECK_MESSAGE( (output_count==Recirc_count), "not all instances were received");
90 
91             int j{};
92             // grab the tags from the queue, record them
93             std::vector<bool> out_tally(maxTag, false);
94             for(int i = 0; i < maxTag; ++i) {
95                 CHECK_MESSAGE( (tag_queue.try_get(j)), "not enough tags in queue");
96                 CHECK_MESSAGE( (!out_tally.at(j-1)), "duplicate tag from queue");
97                 out_tally[j-1] = true;
98             }
99             CHECK_MESSAGE( (!tag_queue.try_get(j)), "Extra tags in recirculation queue");
100 
101             // deconstruct graph
102             input_node_helper<N, JType>::remove_recirc_func_nodes(*my_join, *my_input_join);
103             tbb::flow::remove_edge(*my_join, recreate_tag);
104             makeJoin<N, JType, tbb::flow::tag_matching>::destroy(my_join);
105             tbb::flow::remove_edge(tag_queue, tbb::flow::input_port<0>(*my_input_join));
106             tbb::flow::remove_edge(snode, tbb::flow::input_port<1>(*my_input_join));
107             makeJoin<2, input_join_type, tbb::flow::reserving>::destroy(my_input_join);
108         }
109     }
110 };
111 
112 template<typename JType>
113 class generate_recirc_test {
114 public:
115     typedef tbb::flow::join_node<JType, tbb::flow::tag_matching> join_node_type;
116     static void do_test() {
117         tag_recirculation_test<join_node_type>::test();
118     }
119 };
120 
121 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
122 #include <array>
123 #include <vector>
124 void test_follows_and_precedes_api() {
125     using msg_t = tbb::flow::continue_msg;
126     using JoinOutputType = std::tuple<msg_t, msg_t, msg_t>;
127 
128     std::array<msg_t, 3> messages_for_follows = { {msg_t(), msg_t(), msg_t()} };
129     std::vector<msg_t> messages_for_precedes = {msg_t(), msg_t(), msg_t()};
130 
131     follows_and_precedes_testing::test_follows
132         <msg_t, tbb::flow::join_node<JoinOutputType>, tbb::flow::buffer_node<msg_t>>(messages_for_follows);
133     follows_and_precedes_testing::test_follows
134         <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::queueing>>(messages_for_follows);
135     follows_and_precedes_testing::test_follows
136         <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::reserving>, tbb::flow::buffer_node<msg_t>>(messages_for_follows);
137     auto b = [](msg_t) { return msg_t(); };
138     class hash_compare {
139     public:
140         std::size_t hash(msg_t) const { return 0; }
141         bool equal(msg_t, msg_t) const { return true; }
142     };
143     follows_and_precedes_testing::test_follows
144         <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::key_matching<msg_t, hash_compare>>, tbb::flow::buffer_node<msg_t>>
145         (messages_for_follows, b, b, b);
146 
147     follows_and_precedes_testing::test_precedes
148         <msg_t, tbb::flow::join_node<JoinOutputType>>(messages_for_precedes);
149     follows_and_precedes_testing::test_precedes
150         <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::queueing>>(messages_for_precedes);
151     follows_and_precedes_testing::test_precedes
152         <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::reserving>>(messages_for_precedes);
153     follows_and_precedes_testing::test_precedes
154         <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::key_matching<msg_t, hash_compare>>>
155         (messages_for_precedes, b, b, b);
156 }
157 #endif
158 
159 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
160 void test_deduction_guides() {
161     using namespace tbb::flow;
162 
163     graph g;
164     using tuple_type = std::tuple<int, int, int>;
165     broadcast_node<int> b1(g), b2(g), b3(g);
166     broadcast_node<tuple_type> b4(g);
167     join_node<tuple_type> j0(g);
168 
169 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
170     join_node j1(follows(b1, b2, b3));
171     static_assert(std::is_same_v<decltype(j1), join_node<tuple_type>>);
172 
173     join_node j2(follows(b1, b2, b3), reserving());
174     static_assert(std::is_same_v<decltype(j2), join_node<tuple_type, reserving>>);
175 
176     join_node j3(precedes(b4));
177     static_assert(std::is_same_v<decltype(j3), join_node<tuple_type>>);
178 
179     join_node j4(precedes(b4), reserving());
180     static_assert(std::is_same_v<decltype(j4), join_node<tuple_type, reserving>>);
181 #endif
182 
183     join_node j5(j0);
184     static_assert(std::is_same_v<decltype(j5), join_node<tuple_type>>);
185 }
186 
187 #endif
188 
189 namespace multiple_predecessors {
190 
191 using namespace tbb::flow;
192 
193 using join_node_t = join_node<std::tuple<continue_msg, continue_msg, continue_msg>, reserving>;
194 using queue_node_t = queue_node<std::tuple<continue_msg, continue_msg, continue_msg>>;
195 
196 void twist_join_connections(
197     buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2, buffer_node<continue_msg>& bn3,
198     join_node_t& jn)
199 {
200     // order, in which edges are created/destroyed, is important
201     make_edge(bn1, input_port<0>(jn));
202     make_edge(bn2, input_port<0>(jn));
203     make_edge(bn3, input_port<0>(jn));
204 
205     remove_edge(bn3, input_port<0>(jn));
206     make_edge  (bn3, input_port<2>(jn));
207 
208     remove_edge(bn2, input_port<0>(jn));
209     make_edge  (bn2, input_port<1>(jn));
210 }
211 
212 std::unique_ptr<join_node_t> connect_join_via_make_edge(
213     graph& g, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2,
214     buffer_node<continue_msg>& bn3, queue_node_t& qn)
215 {
216     std::unique_ptr<join_node_t> jn( new join_node_t(g) );
217     twist_join_connections( bn1, bn2, bn3, *jn );
218     make_edge(*jn, qn);
219     return jn;
220 }
221 
222 #if TBB_PREVIEW_FLOW_GRAPH_FEATURES
223 std::unique_ptr<join_node_t> connect_join_via_follows(
224     graph&, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2,
225     buffer_node<continue_msg>& bn3, queue_node_t& qn)
226 {
227     auto bn_set = make_node_set(bn1, bn2, bn3);
228     std::unique_ptr<join_node_t> jn( new join_node_t(follows(bn_set)) );
229     make_edge(*jn, qn);
230     return jn;
231 }
232 
233 std::unique_ptr<join_node_t> connect_join_via_precedes(
234     graph&, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2,
235     buffer_node<continue_msg>& bn3, queue_node_t& qn)
236 {
237     auto qn_set = make_node_set(qn);
238     auto qn_copy_set = qn_set;
239     std::unique_ptr<join_node_t> jn( new join_node_t(precedes(qn_copy_set)) );
240     twist_join_connections( bn1, bn2, bn3, *jn );
241     return jn;
242 }
243 #endif // TBB_PREVIEW_FLOW_GRAPH_FEATURES
244 
245 void run_and_check(
246     graph& g, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2,
247     buffer_node<continue_msg>& bn3, queue_node_t& qn, bool expected)
248 {
249     std::tuple<continue_msg, continue_msg, continue_msg> msg;
250 
251     bn1.try_put(continue_msg());
252     bn2.try_put(continue_msg());
253     bn3.try_put(continue_msg());
254     g.wait_for_all();
255 
256     CHECK_MESSAGE(
257         (qn.try_get(msg) == expected),
258         "Unexpected message absence/existence at the end of the graph."
259     );
260 }
261 
262 template<typename ConnectJoinNodeFunc>
263 void test(ConnectJoinNodeFunc&& connect_join_node) {
264     graph g;
265     buffer_node<continue_msg> bn1(g);
266     buffer_node<continue_msg> bn2(g);
267     buffer_node<continue_msg> bn3(g);
268     queue_node_t qn(g);
269 
270     auto jn = connect_join_node(g, bn1, bn2, bn3, qn);
271 
272     run_and_check(g, bn1, bn2, bn3, qn, /*expected=*/true);
273 
274     remove_edge(bn3, input_port<2>(*jn));
275     remove_edge(bn2, input_port<1>(*jn));
276     remove_edge(bn1, input_port<0>(*jn));
277     remove_edge(*jn, qn);
278 
279     run_and_check(g, bn1, bn2, bn3, qn, /*expected=*/false);
280 }
281 } // namespace multiple_predecessors
282 
283 
284 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
285 //! Test follows and precedes API
286 //! \brief \ref error_guessing
287 TEST_CASE("Test follows and preceedes API"){
288     test_follows_and_precedes_api();
289 }
290 #endif
291 
292 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
293 //! Test deduction guides
294 //! \brief \ref requirement
295 TEST_CASE("Deduction guides test"){
296     test_deduction_guides();
297 }
298 #endif
299 
300 //! Test hash buffers behavior
301 //! \brief \ref error_guessing
302 TEST_CASE("Tagged buffers test"){
303     TestTaggedBuffers();
304 }
305 
306 //! Test with various policies and tuple sizes
307 //! \brief \ref error_guessing
308 TEST_CASE("Main test"){
309     test_main<tbb::flow::queueing>();
310     test_main<tbb::flow::reserving>();
311     test_main<tbb::flow::tag_matching>();
312 }
313 
314 //! Test with recirculating tags
315 //! \brief \ref error_guessing
316 TEST_CASE("Recirculation test"){
317     generate_recirc_test<std::tuple<int,float> >::do_test();
318 }
319 
320 //! Test maintaining correct count of ports without input
321 //! \brief \ref error_guessing
322 TEST_CASE("Test removal of the predecessor while having none") {
323     using namespace multiple_predecessors;
324 
325     test(connect_join_via_make_edge);
326 #if TBB_PREVIEW_FLOW_GRAPH_FEATURES
327     test(connect_join_via_follows);
328     test(connect_join_via_precedes);
329 #endif
330 }
331