1 /*
2 Copyright (c) 2020-2023 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 #if __INTEL_COMPILER && _MSC_VER
18 #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19 #endif
20
21 #include "conformance_flowgraph.h"
22 #include "common/test_invoke.h"
23
24 //! \file conformance_join_node.cpp
25 //! \brief Test for [flow_graph.join_node] specification
26
// Message type delivered to the third input port of the join_node under test.
// Per the inline parameter comments, its default constructor, copy constructor
// and copy assignment are requested to be available.
using input_msg = conformance::message</*default_ctor*/true, /*copy_ctor*/true, /*copy_assign*/true>;
// Output tuple of the three-port join_node used by several tests below
// (port 0: int, port 1: float, port 2: input_msg).
using my_input_tuple = std::tuple<int, float, input_msg>;
29
get_values(conformance::test_push_receiver<my_input_tuple> & rr)30 std::vector<my_input_tuple> get_values( conformance::test_push_receiver<my_input_tuple>& rr ) {
31 std::vector<my_input_tuple> messages;
32 my_input_tuple tmp(0, 0.f, input_msg(0));
33 while(rr.try_get(tmp)) {
34 messages.push_back(tmp);
35 }
36 return messages;
37 }
38
39 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
test_deduction_guides()40 void test_deduction_guides() {
41 using namespace tbb::flow;
42
43 graph g;
44 using tuple_type = std::tuple<int, int, int>;
45 broadcast_node<int> b1(g), b2(g), b3(g);
46 broadcast_node<tuple_type> b4(g);
47 join_node<tuple_type> j0(g);
48
49 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
50 join_node j1(follows(b1, b2, b3));
51 static_assert(std::is_same_v<decltype(j1), join_node<tuple_type>>);
52
53 join_node j2(follows(b1, b2, b3), reserving());
54 static_assert(std::is_same_v<decltype(j2), join_node<tuple_type, reserving>>);
55
56 join_node j3(precedes(b4));
57 static_assert(std::is_same_v<decltype(j3), join_node<tuple_type>>);
58
59 join_node j4(precedes(b4), reserving());
60 static_assert(std::is_same_v<decltype(j4), join_node<tuple_type, reserving>>);
61 #endif
62
63 join_node j5(j0);
64 static_assert(std::is_same_v<decltype(j5), join_node<tuple_type>>);
65 }
66
67 #endif
68
69 //! The node that is constructed has a reference to the same graph object as src.
70 //! The list of predecessors, messages in the input ports, and successors are not copied.
71 //! \brief \ref interface
72 TEST_CASE("join_node copy constructor"){
73 oneapi::tbb::flow::graph g;
74 oneapi::tbb::flow::continue_node<int> node0( g,
__anon798ebb9d0102(oneapi::tbb::flow::continue_msg) 75 [](oneapi::tbb::flow::continue_msg) { return 1; } );
76
77 oneapi::tbb::flow::join_node<std::tuple<int>> node1(g);
78 conformance::test_push_receiver<std::tuple<int>> node2(g);
79 conformance::test_push_receiver<std::tuple<int>> node3(g);
80
81 oneapi::tbb::flow::make_edge(node0, oneapi::tbb::flow::input_port<0>(node1));
82 oneapi::tbb::flow::make_edge(node1, node2);
83 oneapi::tbb::flow::join_node<std::tuple<int>> node_copy(node1);
84
85 oneapi::tbb::flow::make_edge(node_copy, node3);
86
87 oneapi::tbb::flow::input_port<0>(node_copy).try_put(1);
88 g.wait_for_all();
89
90 auto values = conformance::get_values(node3);
91 CHECK_MESSAGE((conformance::get_values(node2).size() == 0 && values.size() == 1), "Copied node doesn`t copy successor");
92
93 node0.try_put(oneapi::tbb::flow::continue_msg());
94 g.wait_for_all();
95
96 CHECK_MESSAGE((conformance::get_values(node2).size() == 1 && conformance::get_values(node3).size() == 0), "Copied node doesn`t copy predecessor");
97
98 oneapi::tbb::flow::remove_edge(node1, node2);
99 oneapi::tbb::flow::input_port<0>(node1).try_put(1);
100 g.wait_for_all();
101 oneapi::tbb::flow::join_node<std::tuple<int>> node_copy2(node1);
102 oneapi::tbb::flow::make_edge(node_copy2, node3);
103 oneapi::tbb::flow::input_port<0>(node_copy2).try_put(2);
104 g.wait_for_all();
105 CHECK_MESSAGE((std::get<0>(conformance::get_values(node3)[0]) == 2), "Copied node doesn`t copy messages in the input ports");
106 }
107
108 //! Test inheritance relations
109 //! \brief \ref interface
110 TEST_CASE("join_node inheritance"){
111 CHECK_MESSAGE((std::is_base_of<oneapi::tbb::flow::graph_node,
112 oneapi::tbb::flow::join_node<my_input_tuple>>::value),
113 "join_node should be derived from graph_node");
114 CHECK_MESSAGE((std::is_base_of<oneapi::tbb::flow::sender<my_input_tuple>,
115 oneapi::tbb::flow::join_node<my_input_tuple>>::value),
116 "join_node should be derived from sender<input_tuple>");
117 }
118
119 //! Test join_node<queueing> behavior and broadcast property
120 //! \brief \ref requirement
121 TEST_CASE("join_node queueing policy and broadcast property") {
122 oneapi::tbb::flow::graph g;
123 oneapi::tbb::flow::function_node<int, int>
__anon798ebb9d0202(const int &i) 124 f1( g, oneapi::tbb::flow::unlimited, [](const int &i) { return i; } );
125 oneapi::tbb::flow::function_node<float, float>
__anon798ebb9d0302(const float &f) 126 f2( g, oneapi::tbb::flow::unlimited, [](const float &f) { return f; } );
127 oneapi::tbb::flow::continue_node<input_msg> c1( g,
__anon798ebb9d0402(oneapi::tbb::flow::continue_msg) 128 [](oneapi::tbb::flow::continue_msg) { return input_msg(1); } );
129
130 oneapi::tbb::flow::join_node<my_input_tuple, oneapi::tbb::flow::queueing> testing_node(g);
131
132 conformance::test_push_receiver<my_input_tuple> q_node(g);
133
134 std::atomic<int> number{1};
135 oneapi::tbb::flow::function_node<my_input_tuple, my_input_tuple>
136 f3( g, oneapi::tbb::flow::unlimited,
__anon798ebb9d0502( const my_input_tuple &t ) 137 [&]( const my_input_tuple &t ) {
138 CHECK_MESSAGE((std::get<0>(t) == number), "Messages must be in first-in first-out order" );
139 CHECK_MESSAGE((std::get<1>(t) == static_cast<float>(number) + 0.5f), "Messages must be in first-in first-out order" );
140 CHECK_MESSAGE((std::get<2>(t) == 1), "Messages must be in first-in first-out order" );
141 ++number;
142 return t;
143 } );
144
145 oneapi::tbb::flow::make_edge(f1, oneapi::tbb::flow::input_port<0>(testing_node));
146 oneapi::tbb::flow::make_edge(f2, oneapi::tbb::flow::input_port<1>(testing_node));
147 oneapi::tbb::flow::make_edge(c1, oneapi::tbb::flow::input_port<2>(testing_node));
148 make_edge(testing_node, f3);
149 make_edge(f3, q_node);
150
151 f1.try_put(1);
152 g.wait_for_all();
153 CHECK_MESSAGE((get_values(q_node).size() == 0),
154 "join_node must broadcast when there is at least one message at each input port");
155 f1.try_put(2);
156 f2.try_put(1.5f);
157 g.wait_for_all();
158 CHECK_MESSAGE((get_values(q_node).size() == 0),
159 "join_node must broadcast when there is at least one message at each input port");
160 f1.try_put(3);
161 f2.try_put(2.5f);
162 c1.try_put(oneapi::tbb::flow::continue_msg());
163 g.wait_for_all();
164 CHECK_MESSAGE((get_values(q_node).size() == 1),
165 "join_node must broadcast when there is at least one message at each input port");
166 f2.try_put(3.5f);
167 c1.try_put(oneapi::tbb::flow::continue_msg());
168 g.wait_for_all();
169 CHECK_MESSAGE((get_values(q_node).size() == 1),
170 "If at least one successor accepts the tuple, the head of each input port’s queue is removed");
171 c1.try_put(oneapi::tbb::flow::continue_msg());
172 g.wait_for_all();
173 CHECK_MESSAGE((get_values(q_node).size() == 1),
174 "If at least one successor accepts the tuple, the head of each input port’s queue is removed");
175 c1.try_put(oneapi::tbb::flow::continue_msg());
176 g.wait_for_all();
177 CHECK_MESSAGE((get_values(q_node).size() == 0),
178 "join_node must broadcast when there is at least one message at each input port");
179
180 oneapi::tbb::flow::remove_edge(testing_node, f3);
181
182 f1.try_put(1);
183 f2.try_put(1);
184 c1.try_put(oneapi::tbb::flow::continue_msg());
185 g.wait_for_all();
186
187 my_input_tuple tmp(0, 0.f, input_msg(0));
188 CHECK_MESSAGE((testing_node.try_get(tmp)), "If no one successor accepts the tuple the messages\
189 must remain in their respective input port queues");
190 CHECK_MESSAGE((tmp == my_input_tuple(1, 1.f, input_msg(1))), "If no one successor accepts the tuple\
191 the messages must remain in their respective input port queues");
192 }
193
194 //! Test join_node<reserving> behavior
195 //! \brief \ref requirement
196 TEST_CASE("join_node reserving policy") {
197 conformance::test_with_reserving_join_node_class<oneapi::tbb::flow::write_once_node<int>>();
198 }
199
//! Custom hash-compare object for key_matching join_node: supplies hash() and equal().
template<typename KeyType>
struct MyHash{
    //! Returns k * 2000 + 3, computed in std::size_t arithmetic.
    std::size_t hash(const KeyType &k) const {
        // Widen before multiplying: computing k * 2000 in int overflows (undefined
        // behavior) once k exceeds roughly INT_MAX / 2000; unsigned arithmetic wraps.
        return static_cast<std::size_t>(k) * 2000u + 3u;
    }

    //! Returns true when the two keys are equal.
    bool equal(const KeyType &k1, const KeyType &k2) const{
        // Compare the keys themselves; comparing hash values would treat
        // colliding keys as equal.
        return k1 == k2;
    }
};
210
211 //! Test join_node<key_matching> behavior
212 //! \brief \ref requirement
213 TEST_CASE("join_node key_matching policy"){
214 oneapi::tbb::flow::graph g;
__anon798ebb9d0602(const oneapi::tbb::flow::continue_msg &) 215 auto body1 = [](const oneapi::tbb::flow::continue_msg &) -> int { return 1; };
__anon798ebb9d0702(const float &val) 216 auto body2 = [](const float &val) -> int { return static_cast<int>(val); };
217
218 oneapi::tbb::flow::join_node<std::tuple<oneapi::tbb::flow::continue_msg, float>,
219 oneapi::tbb::flow::key_matching<int, MyHash<int>>> testing_node(g, body1, body2);
220
221 oneapi::tbb::flow::input_port<0>(testing_node).try_put(oneapi::tbb::flow::continue_msg());
222 oneapi::tbb::flow::input_port<1>(testing_node).try_put(1.3f);
223
224 g.wait_for_all();
225
226 std::tuple<oneapi::tbb::flow::continue_msg, float> tmp;
227 CHECK_MESSAGE((testing_node.try_get(tmp)), "Mapped keys should match.\
228 If no successor accepts the tuple, it is must been saved and will be forwarded on a subsequent try_get");
229 CHECK_MESSAGE((!testing_node.try_get(tmp)), "Message should not exist after item is consumed");
230 }
231
232 //! Test join_node<tag_matching> behavior
233 //! \brief \ref requirement
234 TEST_CASE("join_node tag_matching policy"){
235 oneapi::tbb::flow::graph g;
__anon798ebb9d0802(const oneapi::tbb::flow::continue_msg &) 236 auto body1 = [](const oneapi::tbb::flow::continue_msg &) -> oneapi::tbb::flow::tag_value { return 1; };
__anon798ebb9d0902(const float &val) 237 auto body2 = [](const float &val) -> oneapi::tbb::flow::tag_value { return static_cast<oneapi::tbb::flow::tag_value>(val); };
238
239 oneapi::tbb::flow::join_node<std::tuple<oneapi::tbb::flow::continue_msg, float>,
240 oneapi::tbb::flow::tag_matching> testing_node(g, body1, body2);
241
242 oneapi::tbb::flow::input_port<0>(testing_node).try_put(oneapi::tbb::flow::continue_msg());
243 oneapi::tbb::flow::input_port<1>(testing_node).try_put(1.3f);
244
245 g.wait_for_all();
246
247 std::tuple<oneapi::tbb::flow::continue_msg, float> tmp;
248 CHECK_MESSAGE((testing_node.try_get(tmp) == true), "Mapped keys should match");
249 }
250
251 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Test deduction guides
//! Invokes test_deduction_guides(); its checks are static_asserts, so they are
//! evaluated when this translation unit is compiled rather than at run time.
//! \brief \ref requirement
TEST_CASE("Deduction guides test"){
    test_deduction_guides();
}
257 #endif
258
259 //! Test join_node input_ports() returns a tuple of input ports.
260 //! \brief \ref interface \ref requirement
261 TEST_CASE("join_node output_ports") {
262 oneapi::tbb::flow::graph g;
263 oneapi::tbb::flow::join_node<std::tuple<int>> node(g);
264
265 CHECK_MESSAGE((std::is_same<oneapi::tbb::flow::join_node<std::tuple<int>>::input_ports_type&,
266 decltype(node.input_ports())>::value), "join_node input_ports should returns a tuple of input ports");
267 }
268
269 #if __TBB_CPP17_INVOKE_PRESENT
270
271 template <typename K, typename Body1, typename Body2>
test_invoke_basic(Body1 body1,Body2 body2)272 void test_invoke_basic(Body1 body1, Body2 body2) {
273 static_assert(std::is_same_v<std::decay_t<K>, std::size_t>, "incorrect test setup");
274 using namespace oneapi::tbb::flow;
275 auto generator = [](std::size_t n) { return test_invoke::SmartID<std::size_t>(n); };
276 graph g;
277
278 function_node<std::size_t, test_invoke::SmartID<std::size_t>> f1(g, unlimited, generator);
279 function_node<std::size_t, test_invoke::SmartID<std::size_t>> f2(g, unlimited, generator);
280
281 using tuple_type = std::tuple<test_invoke::SmartID<std::size_t>, test_invoke::SmartID<std::size_t>>;
282 using join_type = join_node<tuple_type, key_matching<K>>;
283
284
285 join_type j(g, body1, body2);
286
287 buffer_node<tuple_type> buf(g);
288
289 make_edge(f1, input_port<0>(j));
290 make_edge(f2, input_port<1>(j));
291 make_edge(j, buf);
292
293 std::size_t objects_count = 100;
294 for (std::size_t i = 0; i < objects_count; ++i) {
295 f1.try_put(i);
296 f2.try_put(objects_count - i - 1);
297 }
298
299 g.wait_for_all();
300
301 std::size_t buf_size = 0;
302 tuple_type tpl;
303
304 while(buf.try_get(tpl)) {
305 ++buf_size;
306 CHECK(std::get<0>(tpl).id == std::get<1>(tpl).id);
307 }
308 CHECK(buf_size == objects_count);
309 }
310
311 //! Test that key_matching join_node uses std::invoke to run the body
312 //! \brief \ref requirement
313 TEST_CASE("key_matching join_node invoke semantics") {
314 test_invoke_basic</*K = */std::size_t>(&test_invoke::SmartID<std::size_t>::get_id, &test_invoke::SmartID<std::size_t>::id);
315 test_invoke_basic</*K = */const std::size_t&>(&test_invoke::SmartID<std::size_t>::get_id_ref, &test_invoke::SmartID<std::size_t>::get_id_ref);
316 }
317 #endif // __TBB_CPP17_INVOKE_PRESENT
318