1 /*
2     Copyright (c) 2020-2023 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #if __INTEL_COMPILER && _MSC_VER
18 #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19 #endif
20 
21 #include "conformance_flowgraph.h"
22 #include "common/test_invoke.h"
23 
24 //! \file conformance_join_node.cpp
25 //! \brief Test for [flow_graph.join_node] specification
26 
// Message type carried by the third tuple element; the three template flags enable its
// default constructor, copy constructor and copy assignment (see the inline tags).
using input_msg = conformance::message</*default_ctor*/true, /*copy_ctor*/true, /*copy_assign*/true>;
// Output tuple of the join_node instances tested below: one element per input port.
using my_input_tuple = std::tuple<int, float, input_msg>;
29 
30 std::vector<my_input_tuple> get_values( conformance::test_push_receiver<my_input_tuple>& rr ) {
31     std::vector<my_input_tuple> messages;
32     my_input_tuple tmp(0, 0.f, input_msg(0));
33     while(rr.try_get(tmp)) {
34         messages.push_back(tmp);
35     }
36     return messages;
37 }
38 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
// Compile-time verification of join_node class template argument deduction.
// Every expectation is a static_assert on the deduced type of a local node.
void test_deduction_guides() {
    namespace flow = tbb::flow;

    using triple = std::tuple<int, int, int>;
    using default_join = flow::join_node<triple>;

    flow::graph g;
    flow::broadcast_node<int> src_a(g), src_b(g), src_c(g);
    flow::broadcast_node<triple> tuple_sink(g);
    default_join prototype(g);

#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
    using reserving_join = flow::join_node<triple, flow::reserving>;

    // Deduction from a set of predecessors, without and with an explicit policy.
    flow::join_node after_sources(flow::follows(src_a, src_b, src_c));
    static_assert(std::is_same_v<decltype(after_sources), default_join>);

    flow::join_node after_sources_reserving(flow::follows(src_a, src_b, src_c), flow::reserving());
    static_assert(std::is_same_v<decltype(after_sources_reserving), reserving_join>);

    // Deduction from a successor, without and with an explicit policy.
    flow::join_node before_sink(flow::precedes(tuple_sink));
    static_assert(std::is_same_v<decltype(before_sink), default_join>);

    flow::join_node before_sink_reserving(flow::precedes(tuple_sink), flow::reserving());
    static_assert(std::is_same_v<decltype(before_sink_reserving), reserving_join>);
#endif

    // Deduction from an existing join_node (copy construction).
    flow::join_node duplicate(prototype);
    static_assert(std::is_same_v<decltype(duplicate), default_join>);
}

#endif
68 
69 //! The node that is constructed has a reference to the same graph object as src.
70 //! The list of predecessors, messages in the input ports, and successors are not copied.
71 //! \brief \ref interface
72 TEST_CASE("join_node copy constructor"){
73     oneapi::tbb::flow::graph g;
74     oneapi::tbb::flow::continue_node<int> node0( g,
75                                 [](oneapi::tbb::flow::continue_msg) { return 1; } );
76 
77     oneapi::tbb::flow::join_node<std::tuple<int>> node1(g);
78     conformance::test_push_receiver<std::tuple<int>> node2(g);
79     conformance::test_push_receiver<std::tuple<int>> node3(g);
80 
81     oneapi::tbb::flow::make_edge(node0, oneapi::tbb::flow::input_port<0>(node1));
82     oneapi::tbb::flow::make_edge(node1, node2);
83     oneapi::tbb::flow::join_node<std::tuple<int>> node_copy(node1);
84 
85     oneapi::tbb::flow::make_edge(node_copy, node3);
86 
87     oneapi::tbb::flow::input_port<0>(node_copy).try_put(1);
88     g.wait_for_all();
89 
90     auto values = conformance::get_values(node3);
91     CHECK_MESSAGE((conformance::get_values(node2).size() == 0 && values.size() == 1), "Copied node doesn`t copy successor");
92 
93     node0.try_put(oneapi::tbb::flow::continue_msg());
94     g.wait_for_all();
95 
96     CHECK_MESSAGE((conformance::get_values(node2).size() == 1 && conformance::get_values(node3).size() == 0), "Copied node doesn`t copy predecessor");
97 
98     oneapi::tbb::flow::remove_edge(node1, node2);
99     oneapi::tbb::flow::input_port<0>(node1).try_put(1);
100     g.wait_for_all();
101     oneapi::tbb::flow::join_node<std::tuple<int>> node_copy2(node1);
102     oneapi::tbb::flow::make_edge(node_copy2, node3);
103     oneapi::tbb::flow::input_port<0>(node_copy2).try_put(2);
104     g.wait_for_all();
105     CHECK_MESSAGE((std::get<0>(conformance::get_values(node3)[0]) == 2), "Copied node doesn`t copy messages in the input ports");
106 }
107 
108 //! Test inheritance relations
109 //! \brief \ref interface
110 TEST_CASE("join_node inheritance"){
111     CHECK_MESSAGE((std::is_base_of<oneapi::tbb::flow::graph_node,
112                    oneapi::tbb::flow::join_node<my_input_tuple>>::value),
113                    "join_node should be derived from graph_node");
114     CHECK_MESSAGE((std::is_base_of<oneapi::tbb::flow::sender<my_input_tuple>,
115                    oneapi::tbb::flow::join_node<my_input_tuple>>::value),
116                    "join_node should be derived from sender<input_tuple>");
117 }
118 
119 //! Test join_node<queueing> behavior and broadcast property
120 //! \brief \ref requirement
121 TEST_CASE("join_node queueing policy and broadcast property") {
122     oneapi::tbb::flow::graph g;
123     oneapi::tbb::flow::function_node<int, int>
124         f1( g, oneapi::tbb::flow::unlimited, [](const int &i) { return i; } );
125     oneapi::tbb::flow::function_node<float, float>
126         f2( g, oneapi::tbb::flow::unlimited, [](const float &f) { return f; } );
127     oneapi::tbb::flow::continue_node<input_msg> c1( g,
128                             [](oneapi::tbb::flow::continue_msg) { return input_msg(1); } );
129 
130     oneapi::tbb::flow::join_node<my_input_tuple, oneapi::tbb::flow::queueing> testing_node(g);
131 
132     conformance::test_push_receiver<my_input_tuple> q_node(g);
133 
134     std::atomic<int> number{1};
135     oneapi::tbb::flow::function_node<my_input_tuple, my_input_tuple>
136         f3( g, oneapi::tbb::flow::unlimited,
137             [&]( const my_input_tuple &t ) {
138                 CHECK_MESSAGE((std::get<0>(t) == number), "Messages must be in first-in first-out order" );
139                 CHECK_MESSAGE((std::get<1>(t) == static_cast<float>(number) + 0.5f), "Messages must be in first-in first-out order" );
140                 CHECK_MESSAGE((std::get<2>(t) == 1), "Messages must be in first-in first-out order" );
141                 ++number;
142                 return t;
143             } );
144 
145     oneapi::tbb::flow::make_edge(f1, oneapi::tbb::flow::input_port<0>(testing_node));
146     oneapi::tbb::flow::make_edge(f2, oneapi::tbb::flow::input_port<1>(testing_node));
147     oneapi::tbb::flow::make_edge(c1, oneapi::tbb::flow::input_port<2>(testing_node));
148     make_edge(testing_node, f3);
149     make_edge(f3, q_node);
150 
151     f1.try_put(1);
152     g.wait_for_all();
153     CHECK_MESSAGE((get_values(q_node).size() == 0),
154         "join_node must broadcast when there is at least one message at each input port");
155     f1.try_put(2);
156     f2.try_put(1.5f);
157     g.wait_for_all();
158     CHECK_MESSAGE((get_values(q_node).size() == 0),
159         "join_node must broadcast when there is at least one message at each input port");
160     f1.try_put(3);
161     f2.try_put(2.5f);
162     c1.try_put(oneapi::tbb::flow::continue_msg());
163     g.wait_for_all();
164     CHECK_MESSAGE((get_values(q_node).size() == 1),
165         "join_node must broadcast when there is at least one message at each input port");
166     f2.try_put(3.5f);
167     c1.try_put(oneapi::tbb::flow::continue_msg());
168     g.wait_for_all();
169     CHECK_MESSAGE((get_values(q_node).size() == 1),
170         "If at least one successor accepts the tuple, the head of each input port’s queue is removed");
171     c1.try_put(oneapi::tbb::flow::continue_msg());
172     g.wait_for_all();
173     CHECK_MESSAGE((get_values(q_node).size() == 1),
174         "If at least one successor accepts the tuple, the head of each input port’s queue is removed");
175     c1.try_put(oneapi::tbb::flow::continue_msg());
176     g.wait_for_all();
177     CHECK_MESSAGE((get_values(q_node).size() == 0),
178         "join_node must broadcast when there is at least one message at each input port");
179 
180     oneapi::tbb::flow::remove_edge(testing_node, f3);
181 
182     f1.try_put(1);
183     f2.try_put(1);
184     c1.try_put(oneapi::tbb::flow::continue_msg());
185     g.wait_for_all();
186 
187     my_input_tuple tmp(0, 0.f, input_msg(0));
188     CHECK_MESSAGE((testing_node.try_get(tmp)), "If no one successor accepts the tuple the messages\
189         must remain in their respective input port queues");
190     CHECK_MESSAGE((tmp == my_input_tuple(1, 1.f, input_msg(1))), "If no one successor accepts the tuple\
191         the messages must remain in their respective input port queues");
192 }
193 
194 //! Test join_node<reserving> behavior
195 //! \brief \ref requirement
196 TEST_CASE("join_node reserving policy") {
197     conformance::test_with_reserving_join_node_class<oneapi::tbb::flow::write_once_node<int>>();
198 }
199 
// Hash/equality functor used as the second argument of key_matching<> in the
// key_matching test below.
template<typename KeyType>
struct MyHash{
    // Returns k * 2000 + 3.
    // The constants are std::size_t so that an integral key is converted to an unsigned
    // type before the arithmetic: large keys wrap around instead of hitting
    // signed-overflow undefined behaviour. In-range results are unchanged.
    std::size_t hash(const KeyType &k) const {
        const std::size_t factor = 2000;
        const std::size_t offset = 3;
        return k * factor + offset;
    }

    // Two keys are treated as equal when their hash values are equal.
    bool equal(const KeyType &k1, const KeyType &k2) const{
        return hash(k1) == hash(k2);
    }
};
210 
211 //! Test join_node<key_matching> behavior
212 //! \brief \ref requirement
213 TEST_CASE("join_node key_matching policy"){
214     oneapi::tbb::flow::graph g;
215     auto body1 = [](const oneapi::tbb::flow::continue_msg &) -> int { return 1; };
216     auto body2 = [](const float &val) -> int { return static_cast<int>(val); };
217 
218     oneapi::tbb::flow::join_node<std::tuple<oneapi::tbb::flow::continue_msg, float>,
219         oneapi::tbb::flow::key_matching<int, MyHash<int>>> testing_node(g, body1, body2);
220 
221     oneapi::tbb::flow::input_port<0>(testing_node).try_put(oneapi::tbb::flow::continue_msg());
222     oneapi::tbb::flow::input_port<1>(testing_node).try_put(1.3f);
223 
224     g.wait_for_all();
225 
226     std::tuple<oneapi::tbb::flow::continue_msg, float> tmp;
227     CHECK_MESSAGE((testing_node.try_get(tmp)), "Mapped keys should match.\
228         If no successor accepts the tuple, it is must been saved and will be forwarded on a subsequent try_get");
229     CHECK_MESSAGE((!testing_node.try_get(tmp)), "Message should not exist after item is consumed");
230 }
231 
232 //! Test join_node<tag_matching> behavior
233 //! \brief \ref requirement
234 TEST_CASE("join_node tag_matching policy"){
235     oneapi::tbb::flow::graph g;
236     auto body1 = [](const oneapi::tbb::flow::continue_msg &) -> oneapi::tbb::flow::tag_value { return 1; };
237     auto body2 = [](const float &val) -> oneapi::tbb::flow::tag_value { return static_cast<oneapi::tbb::flow::tag_value>(val); };
238 
239     oneapi::tbb::flow::join_node<std::tuple<oneapi::tbb::flow::continue_msg, float>,
240         oneapi::tbb::flow::tag_matching> testing_node(g, body1, body2);
241 
242     oneapi::tbb::flow::input_port<0>(testing_node).try_put(oneapi::tbb::flow::continue_msg());
243     oneapi::tbb::flow::input_port<1>(testing_node).try_put(1.3f);
244 
245     g.wait_for_all();
246 
247     std::tuple<oneapi::tbb::flow::continue_msg, float> tmp;
248     CHECK_MESSAGE((testing_node.try_get(tmp) == true), "Mapped keys should match");
249 }
250 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Test deduction guides
//! \brief \ref requirement
TEST_CASE("Deduction guides test"){
    // test_deduction_guides() verifies the deduced types with static_asserts, so the
    // expectations are enforced at compile time; the call here only constructs the nodes.
    test_deduction_guides();
}
#endif
258 
259 //! Test join_node input_ports() returns a tuple of input ports.
260 //! \brief \ref interface \ref requirement
261 TEST_CASE("join_node output_ports") {
262     oneapi::tbb::flow::graph g;
263     oneapi::tbb::flow::join_node<std::tuple<int>> node(g);
264 
265     CHECK_MESSAGE((std::is_same<oneapi::tbb::flow::join_node<std::tuple<int>>::input_ports_type&,
266         decltype(node.input_ports())>::value), "join_node input_ports should returns a tuple of input ports");
267 }
268 
269 #if __TBB_CPP17_INVOKE_PRESENT
270 
271 template <typename K, typename Body1, typename Body2>
272 void test_invoke_basic(Body1 body1, Body2 body2) {
273     static_assert(std::is_same_v<std::decay_t<K>, std::size_t>, "incorrect test setup");
274     using namespace oneapi::tbb::flow;
275     auto generator = [](std::size_t n) { return test_invoke::SmartID<std::size_t>(n); };
276     graph g;
277 
278     function_node<std::size_t, test_invoke::SmartID<std::size_t>> f1(g, unlimited, generator);
279     function_node<std::size_t, test_invoke::SmartID<std::size_t>> f2(g, unlimited, generator);
280 
281     using tuple_type = std::tuple<test_invoke::SmartID<std::size_t>, test_invoke::SmartID<std::size_t>>;
282     using join_type = join_node<tuple_type, key_matching<K>>;
283 
284 
285     join_type j(g, body1, body2);
286 
287     buffer_node<tuple_type> buf(g);
288 
289     make_edge(f1, input_port<0>(j));
290     make_edge(f2, input_port<1>(j));
291     make_edge(j, buf);
292 
293     std::size_t objects_count = 100;
294     for (std::size_t i = 0; i < objects_count; ++i) {
295         f1.try_put(i);
296         f2.try_put(objects_count - i - 1);
297     }
298 
299     g.wait_for_all();
300 
301     std::size_t buf_size = 0;
302     tuple_type tpl;
303 
304     while(buf.try_get(tpl)) {
305         ++buf_size;
306         CHECK(std::get<0>(tpl).id == std::get<1>(tpl).id);
307     }
308     CHECK(buf_size == objects_count);
309 }
310 
311 //! Test that key_matching join_node uses std::invoke to run the body
312 //! \brief \ref requirement
313 TEST_CASE("key_matching join_node invoke semantics") {
314     test_invoke_basic</*K = */std::size_t>(&test_invoke::SmartID<std::size_t>::get_id, &test_invoke::SmartID<std::size_t>::id);
315     test_invoke_basic</*K = */const std::size_t&>(&test_invoke::SmartID<std::size_t>::get_id_ref, &test_invoke::SmartID<std::size_t>::get_id_ref);
316 }
317 #endif // __TBB_CPP17_INVOKE_PRESENT
318