xref: /oneTBB/test/tbb/test_join_node.cpp (revision de0109be)
1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #ifdef TBB_TEST_LOW_WORKLOAD
18     #undef MAX_TUPLE_TEST_SIZE
19     #define MAX_TUPLE_TEST_SIZE 3
20 #endif
21 
22 #include "common/config.h"
23 
24 #include "test_join_node.h"
25 
26 //! \file test_join_node.cpp
27 //! \brief Test for [flow_graph.join_node] specification
28 
29 
30 static std::atomic<int> output_count;
31 
32 // get the tag from the output tuple and emit it.
33 // the first tuple component is tag * 2 cast to the type
34 template<typename OutputTupleType>
35 class recirc_output_func_body {
36 public:
37     // we only need this to use input_node_helper
38     typedef typename tbb::flow::join_node<OutputTupleType, tbb::flow::tag_matching> join_node_type;
39     static const int N = std::tuple_size<OutputTupleType>::value;
40     int operator()(const OutputTupleType &v) {
41         int out = int(std::get<0>(v))/2;
42         input_node_helper<N, join_node_type>::only_check_value(out, v);
43         ++output_count;
44         return out;
45     }
46 };
47 
48 template<typename JType>
49 class tag_recirculation_test {
50 public:
51     typedef typename JType::output_type TType;
52     typedef typename std::tuple<int, tbb::flow::continue_msg> input_tuple_type;
53     typedef tbb::flow::join_node<input_tuple_type, tbb::flow::reserving> input_join_type;
54     static const int N = std::tuple_size<TType>::value;
55     static void test() {
56         input_node_helper<N, JType>::print_remark("Recirculation test of tag-matching join");
57         INFO(" >\n");
58         for(int maxTag = 1; maxTag <10; maxTag *= 3) {
59             for(int i = 0; i < N; ++i) all_input_nodes[i][0] = NULL;
60 
61             tbb::flow::graph g;
62             // this is the tag-matching join we're testing
63             JType * my_join = makeJoin<N, JType, tbb::flow::tag_matching>::create(g);
64             // input_node for continue messages
65             tbb::flow::input_node<tbb::flow::continue_msg> snode(g, recirc_input_node_body());
66             // reserving join that matches recirculating tags with continue messages.
67             input_join_type * my_input_join = makeJoin<2, input_join_type, tbb::flow::reserving>::create(g);
68             // tbb::flow::make_edge(snode, tbb::flow::input_port<1>(*my_input_join));
69             tbb::flow::make_edge(snode, std::get<1>(my_input_join->input_ports()));
70             // queue to hold the tags
71             tbb::flow::queue_node<int> tag_queue(g);
72             tbb::flow::make_edge(tag_queue, tbb::flow::input_port<0>(*my_input_join));
73             // add all the function_nodes that are inputs to the tag-matching join
74             input_node_helper<N, JType>::add_recirc_func_nodes(*my_join, *my_input_join, g);
75             // add the function_node that accepts the output of the join and emits the int tag it was based on
76             tbb::flow::function_node<TType, int> recreate_tag(g, tbb::flow::unlimited, recirc_output_func_body<TType>());
77             tbb::flow::make_edge(*my_join, recreate_tag);
78             // now the recirculating part (output back to the queue)
79             tbb::flow::make_edge(recreate_tag, tag_queue);
80 
81             // put the tags into the queue
82             for(int t = 1; t<=maxTag; ++t) tag_queue.try_put(t);
83 
84             input_count = Recirc_count;
85             output_count = 0;
86 
87             // start up the source node to get things going
88             snode.activate();
89 
90             // wait for everything to stop
91             g.wait_for_all();
92 
93             CHECK_MESSAGE( (output_count==Recirc_count), "not all instances were received");
94 
95             int j{};
96             // grab the tags from the queue, record them
97             std::vector<bool> out_tally(maxTag, false);
98             for(int i = 0; i < maxTag; ++i) {
99                 CHECK_MESSAGE( (tag_queue.try_get(j)), "not enough tags in queue");
100                 CHECK_MESSAGE( (!out_tally.at(j-1)), "duplicate tag from queue");
101                 out_tally[j-1] = true;
102             }
103             CHECK_MESSAGE( (!tag_queue.try_get(j)), "Extra tags in recirculation queue");
104 
105             // deconstruct graph
106             input_node_helper<N, JType>::remove_recirc_func_nodes(*my_join, *my_input_join);
107             tbb::flow::remove_edge(*my_join, recreate_tag);
108             makeJoin<N, JType, tbb::flow::tag_matching>::destroy(my_join);
109             tbb::flow::remove_edge(tag_queue, tbb::flow::input_port<0>(*my_input_join));
110             tbb::flow::remove_edge(snode, tbb::flow::input_port<1>(*my_input_join));
111             makeJoin<2, input_join_type, tbb::flow::reserving>::destroy(my_input_join);
112         }
113     }
114 };
115 
116 template<typename JType>
117 class generate_recirc_test {
118 public:
119     typedef tbb::flow::join_node<JType, tbb::flow::tag_matching> join_node_type;
120     static void do_test() {
121         tag_recirculation_test<join_node_type>::test();
122     }
123 };
124 
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
#include <array>
#include <vector>
//! Builds join_nodes through the follows()/precedes() node-set constructors.
//! Covers the default, queueing, reserving and key_matching policies. The
//! follows_and_precedes_testing helpers check the resulting connections.
void test_follows_and_precedes_api() {
    using msg_t = tbb::flow::continue_msg;
    using JoinOutputType = std::tuple<msg_t, msg_t, msg_t>;

    // Hash/equality policy for key_matching: every continue_msg maps to the same key.
    class hash_compare {
    public:
        std::size_t hash(msg_t) const { return 0; }
        bool equal(msg_t, msg_t) const { return true; }
    };

    // The default-policy and explicit-queueing spellings are kept distinct on purpose.
    using default_join_t   = tbb::flow::join_node<JoinOutputType>;
    using queueing_join_t  = tbb::flow::join_node<JoinOutputType, tbb::flow::queueing>;
    using reserving_join_t = tbb::flow::join_node<JoinOutputType, tbb::flow::reserving>;
    using key_join_t       = tbb::flow::join_node<JoinOutputType, tbb::flow::key_matching<msg_t, hash_compare>>;
    using buffer_t         = tbb::flow::buffer_node<msg_t>;

    // Key extractor for the key_matching join. It is passed three times,
    // presumably one per input port -- see follows_and_precedes_testing.
    auto key_of = [](msg_t) { return msg_t(); };

    std::array<msg_t, 3> messages_for_follows = { {msg_t(), msg_t(), msg_t()} };
    std::vector<msg_t> messages_for_precedes = {msg_t(), msg_t(), msg_t()};

    follows_and_precedes_testing::test_follows<msg_t, default_join_t, buffer_t>(messages_for_follows);
    follows_and_precedes_testing::test_follows<msg_t, queueing_join_t>(messages_for_follows);
    follows_and_precedes_testing::test_follows<msg_t, reserving_join_t, buffer_t>(messages_for_follows);
    follows_and_precedes_testing::test_follows<msg_t, key_join_t, buffer_t>(
        messages_for_follows, key_of, key_of, key_of);

    follows_and_precedes_testing::test_precedes<msg_t, default_join_t>(messages_for_precedes);
    follows_and_precedes_testing::test_precedes<msg_t, queueing_join_t>(messages_for_precedes);
    follows_and_precedes_testing::test_precedes<msg_t, reserving_join_t>(messages_for_precedes);
    follows_and_precedes_testing::test_precedes<msg_t, key_join_t>(
        messages_for_precedes, key_of, key_of, key_of);
}
#endif
162 
163 namespace multiple_predecessors {
164 
165 using namespace tbb::flow;
166 
167 using join_node_t = join_node<std::tuple<continue_msg, continue_msg, continue_msg>, reserving>;
168 using queue_node_t = queue_node<std::tuple<continue_msg, continue_msg, continue_msg>>;
169 
170 void twist_join_connections(
171     buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2, buffer_node<continue_msg>& bn3,
172     join_node_t& jn)
173 {
174     // order, in which edges are created/destroyed, is important
175     make_edge(bn1, input_port<0>(jn));
176     make_edge(bn2, input_port<0>(jn));
177     make_edge(bn3, input_port<0>(jn));
178 
179     remove_edge(bn3, input_port<0>(jn));
180     make_edge  (bn3, input_port<2>(jn));
181 
182     remove_edge(bn2, input_port<0>(jn));
183     make_edge  (bn2, input_port<1>(jn));
184 }
185 
186 std::unique_ptr<join_node_t> connect_join_via_make_edge(
187     graph& g, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2,
188     buffer_node<continue_msg>& bn3, queue_node_t& qn)
189 {
190     std::unique_ptr<join_node_t> jn( new join_node_t(g) );
191     twist_join_connections( bn1, bn2, bn3, *jn );
192     make_edge(*jn, qn);
193     return jn;
194 }
195 
196 #if TBB_PREVIEW_FLOW_GRAPH_FEATURES
197 std::unique_ptr<join_node_t> connect_join_via_follows(
198     graph&, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2,
199     buffer_node<continue_msg>& bn3, queue_node_t& qn)
200 {
201     auto bn_set = make_node_set(bn1, bn2, bn3);
202     std::unique_ptr<join_node_t> jn( new join_node_t(follows(bn_set)) );
203     make_edge(*jn, qn);
204     return jn;
205 }
206 
207 std::unique_ptr<join_node_t> connect_join_via_precedes(
208     graph&, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2,
209     buffer_node<continue_msg>& bn3, queue_node_t& qn)
210 {
211     auto qn_set = make_node_set(qn);
212     auto qn_copy_set = qn_set;
213     std::unique_ptr<join_node_t> jn( new join_node_t(precedes(qn_copy_set)) );
214     twist_join_connections( bn1, bn2, bn3, *jn );
215     return jn;
216 }
217 #endif // TBB_PREVIEW_FLOW_GRAPH_FEATURES
218 
219 void run_and_check(
220     graph& g, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2,
221     buffer_node<continue_msg>& bn3, queue_node_t& qn, bool expected)
222 {
223     std::tuple<continue_msg, continue_msg, continue_msg> msg;
224 
225     bn1.try_put(continue_msg());
226     bn2.try_put(continue_msg());
227     bn3.try_put(continue_msg());
228     g.wait_for_all();
229 
230     CHECK_MESSAGE(
231         (qn.try_get(msg) == expected),
232         "Unexpected message absence/existence at the end of the graph."
233     );
234 }
235 
236 template<typename ConnectJoinNodeFunc>
237 void test(ConnectJoinNodeFunc&& connect_join_node) {
238     graph g;
239     buffer_node<continue_msg> bn1(g);
240     buffer_node<continue_msg> bn2(g);
241     buffer_node<continue_msg> bn3(g);
242     queue_node_t qn(g);
243 
244     auto jn = connect_join_node(g, bn1, bn2, bn3, qn);
245 
246     run_and_check(g, bn1, bn2, bn3, qn, /*expected=*/true);
247 
248     remove_edge(bn3, input_port<2>(*jn));
249     remove_edge(bn2, input_port<1>(*jn));
250     remove_edge(bn1, input_port<0>(*jn));
251     remove_edge(*jn, qn);
252 
253     run_and_check(g, bn1, bn2, bn3, qn, /*expected=*/false);
254 }
255 } // namespace multiple_predecessors
256 
257 
258 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
259 //! Test follows and precedes API
260 //! \brief \ref error_guessing
261 TEST_CASE("Test follows and preceedes API"){
262     test_follows_and_precedes_api();
263 }
264 #endif
265 
266 //! Test hash buffers behavior
267 //! \brief \ref error_guessing
268 TEST_CASE("Tagged buffers test"){
269     TestTaggedBuffers();
270 }
271 
272 //! Test with various policies and tuple sizes
273 //! \brief \ref error_guessing
274 TEST_CASE("Main test"){
275     test_main<tbb::flow::queueing>();
276     test_main<tbb::flow::reserving>();
277     test_main<tbb::flow::tag_matching>();
278 }
279 
280 //! Test with recirculating tags
281 //! \brief \ref error_guessing
282 TEST_CASE("Recirculation test"){
283     generate_recirc_test<std::tuple<int,float> >::do_test();
284 }
285 
286 //! Test maintaining correct count of ports without input
287 //! \brief \ref error_guessing
288 TEST_CASE("Test removal of the predecessor while having none") {
289     using namespace multiple_predecessors;
290 
291     test(connect_join_via_make_edge);
292 #if TBB_PREVIEW_FLOW_GRAPH_FEATURES
293     test(connect_join_via_follows);
294     test(connect_join_via_precedes);
295 #endif
296 }
297