1 /*
2     Copyright (c) 2020-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #if __INTEL_COMPILER && _MSC_VER
18 #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19 #endif
20 
21 #include "conformance_flowgraph.h"
22 
23 //! \file conformance_join_node.cpp
24 //! \brief Test for [flow_graph.join_node] specification
25 
// Message type delivered to the third input port of the joins under test. Per the
// inline flags it is default-constructible, copy-constructible and copy-assignable.
using input_msg = conformance::message</*default_ctor*/true, /*copy_ctor*/true, /*copy_assign*/true>;
// Tuple produced by the join_node instances under test: one element per input port
// (port 0: int, port 1: float, port 2: input_msg).
using my_input_tuple = std::tuple<int, float, input_msg>;
28 
29 std::vector<my_input_tuple> get_values( conformance::test_push_receiver<my_input_tuple>& rr ) {
30     std::vector<my_input_tuple> messages;
31     int val = 0;
32     for(my_input_tuple tmp(0, 0.f, input_msg(0)); rr.try_get(tmp); ++val) {
33         messages.push_back(tmp);
34     }
35     return messages;
36 }
37 
38 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
39 void test_deduction_guides() {
40     using namespace tbb::flow;
41 
42     graph g;
43     using tuple_type = std::tuple<int, int, int>;
44     broadcast_node<int> b1(g), b2(g), b3(g);
45     broadcast_node<tuple_type> b4(g);
46     join_node<tuple_type> j0(g);
47 
48 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
49     join_node j1(follows(b1, b2, b3));
50     static_assert(std::is_same_v<decltype(j1), join_node<tuple_type>>);
51 
52     join_node j2(follows(b1, b2, b3), reserving());
53     static_assert(std::is_same_v<decltype(j2), join_node<tuple_type, reserving>>);
54 
55     join_node j3(precedes(b4));
56     static_assert(std::is_same_v<decltype(j3), join_node<tuple_type>>);
57 
58     join_node j4(precedes(b4), reserving());
59     static_assert(std::is_same_v<decltype(j4), join_node<tuple_type, reserving>>);
60 #endif
61 
62     join_node j5(j0);
63     static_assert(std::is_same_v<decltype(j5), join_node<tuple_type>>);
64 }
65 
66 #endif
67 
68 //! The node that is constructed has a reference to the same graph object as src.
69 //! The list of predecessors, messages in the input ports, and successors are not copied.
70 //! \brief \ref interface
71 TEST_CASE("join_node copy constructor"){
72     oneapi::tbb::flow::graph g;
73     oneapi::tbb::flow::continue_node<int> node0( g,
74                                 [](oneapi::tbb::flow::continue_msg) { return 1; } );
75 
76     oneapi::tbb::flow::join_node<std::tuple<int>> node1(g);
77     conformance::test_push_receiver<std::tuple<int>> node2(g);
78     conformance::test_push_receiver<std::tuple<int>> node3(g);
79 
80     oneapi::tbb::flow::make_edge(node0, oneapi::tbb::flow::input_port<0>(node1));
81     oneapi::tbb::flow::make_edge(node1, node2);
82     oneapi::tbb::flow::join_node<std::tuple<int>> node_copy(node1);
83 
84     oneapi::tbb::flow::make_edge(node_copy, node3);
85 
86     oneapi::tbb::flow::input_port<0>(node_copy).try_put(1);
87     g.wait_for_all();
88 
89     auto values = conformance::get_values(node3);
90     CHECK_MESSAGE((conformance::get_values(node2).size() == 0 && values.size() == 1), "Copied node doesn`t copy successor");
91 
92     node0.try_put(oneapi::tbb::flow::continue_msg());
93     g.wait_for_all();
94 
95     CHECK_MESSAGE((conformance::get_values(node2).size() == 1 && conformance::get_values(node3).size() == 0), "Copied node doesn`t copy predecessor");
96 
97     oneapi::tbb::flow::remove_edge(node1, node2);
98     oneapi::tbb::flow::input_port<0>(node1).try_put(1);
99     g.wait_for_all();
100     oneapi::tbb::flow::join_node<std::tuple<int>> node_copy2(node1);
101     oneapi::tbb::flow::make_edge(node_copy2, node3);
102     oneapi::tbb::flow::input_port<0>(node_copy2).try_put(2);
103     g.wait_for_all();
104     CHECK_MESSAGE((std::get<0>(conformance::get_values(node3)[0]) == 2), "Copied node doesn`t copy messages in the input ports");
105 }
106 
107 //! Test inheritance relations
108 //! \brief \ref interface
109 TEST_CASE("join_node inheritance"){
110     CHECK_MESSAGE((std::is_base_of<oneapi::tbb::flow::graph_node,
111                    oneapi::tbb::flow::join_node<my_input_tuple>>::value),
112                    "join_node should be derived from graph_node");
113     CHECK_MESSAGE((std::is_base_of<oneapi::tbb::flow::sender<my_input_tuple>,
114                    oneapi::tbb::flow::join_node<my_input_tuple>>::value),
115                    "join_node should be derived from sender<input_tuple>");
116 }
117 
118 //! Test join_node<queueing> behavior and broadcast property
119 //! \brief \ref requirement
TEST_CASE("join_node queueing policy and broadcast property") {
    oneapi::tbb::flow::graph g;
    // Producers for the three input ports: f1 -> port 0 (int), f2 -> port 1 (float),
    // c1 -> port 2 (emits input_msg(1) for every continue_msg it receives).
    oneapi::tbb::flow::function_node<int, int>
        f1( g, oneapi::tbb::flow::unlimited, [](const int &i) { return i; } );
    oneapi::tbb::flow::function_node<float, float>
        f2( g, oneapi::tbb::flow::unlimited, [](const float &f) { return f; } );
    oneapi::tbb::flow::continue_node<input_msg> c1( g,
                            [](oneapi::tbb::flow::continue_msg) { return input_msg(1); } );

    oneapi::tbb::flow::join_node<my_input_tuple, oneapi::tbb::flow::queueing> testing_node(g);

    // Collects whatever f3 forwards. get_values(q_node) drains it, so each check below
    // counts only the tuples emitted since the previous check.
    conformance::test_push_receiver<my_input_tuple> q_node(g);

    // 1-based index of the next tuple f3 expects. The inputs below are chosen so that,
    // if every port is served first-in first-out, the k-th joined tuple is
    // (k, k + 0.5f, input_msg(1)).
    std::atomic<int> number{1};
    oneapi::tbb::flow::function_node<my_input_tuple, my_input_tuple>
        f3( g, oneapi::tbb::flow::unlimited,
            [&]( const my_input_tuple &t ) {
                CHECK_MESSAGE((std::get<0>(t) == number), "Messages must be in first-in first-out order" );
                CHECK_MESSAGE((std::get<1>(t) == static_cast<float>(number) + 0.5f), "Messages must be in first-in first-out order" );
                CHECK_MESSAGE((std::get<2>(t) == 1), "Messages must be in first-in first-out order" );
                ++number;
                return t;
            } );

    oneapi::tbb::flow::make_edge(f1, oneapi::tbb::flow::input_port<0>(testing_node));
    oneapi::tbb::flow::make_edge(f2, oneapi::tbb::flow::input_port<1>(testing_node));
    oneapi::tbb::flow::make_edge(c1, oneapi::tbb::flow::input_port<2>(testing_node));
    make_edge(testing_node, f3);
    make_edge(f3, q_node);

    // Expected queues: port0 [1], port1 [], port2 [] -> nothing can be joined.
    f1.try_put(1);
    g.wait_for_all();
    CHECK_MESSAGE((get_values(q_node).size() == 0),
        "join_node must broadcast when there is at least one message at each input port");
    // Expected queues: port0 [1, 2], port1 [1.5], port2 [] -> still nothing to join.
    f1.try_put(2);
    f2.try_put(1.5f);
    g.wait_for_all();
    CHECK_MESSAGE((get_values(q_node).size() == 0),
        "join_node must broadcast when there is at least one message at each input port");
    // The first port-2 message completes exactly one tuple: (1, 1.5f, input_msg(1)).
    // Expected queues afterwards: port0 [2, 3], port1 [2.5], port2 [].
    f1.try_put(3);
    f2.try_put(2.5f);
    c1.try_put(oneapi::tbb::flow::continue_msg());
    g.wait_for_all();
    CHECK_MESSAGE((get_values(q_node).size() == 1),
        "join_node must broadcast when there is at least one message at each input port");
    // Queue heads 2 and 2.5f join with the new port-2 message: (2, 2.5f, input_msg(1)).
    f2.try_put(3.5f);
    c1.try_put(oneapi::tbb::flow::continue_msg());
    g.wait_for_all();
    CHECK_MESSAGE((get_values(q_node).size() == 1),
        "If at least one successor accepts the tuple, the head of each input port’s queue is removed");
    // Queue heads 3 and 3.5f join: (3, 3.5f, input_msg(1)); ports 0 and 1 are now empty.
    c1.try_put(oneapi::tbb::flow::continue_msg());
    g.wait_for_all();
    CHECK_MESSAGE((get_values(q_node).size() == 1),
        "If at least one successor accepts the tuple, the head of each input port’s queue is removed");
    // Ports 0 and 1 are empty, so this port-2 message stays queued and nothing is emitted.
    c1.try_put(oneapi::tbb::flow::continue_msg());
    g.wait_for_all();
    CHECK_MESSAGE((get_values(q_node).size() == 0),
        "join_node must broadcast when there is at least one message at each input port");

    // From here on testing_node has no successor, so a joinable tuple cannot be pushed.
    oneapi::tbb::flow::remove_edge(testing_node, f3);

    // Expected queues: port0 [1], port1 [1.0f], port2 [input_msg(1), input_msg(1)]
    // (one message is left over from the previous step).
    f1.try_put(1);
    f2.try_put(1);
    c1.try_put(oneapi::tbb::flow::continue_msg());
    g.wait_for_all();

    // The undelivered tuple must still be obtainable by pulling it with try_get.
    my_input_tuple tmp(0, 0.f, input_msg(0));
    CHECK_MESSAGE((testing_node.try_get(tmp)), "If no one successor accepts the tuple the messages\
        must remain in their respective input port queues");
    CHECK_MESSAGE((tmp == my_input_tuple(1, 1.f, input_msg(1))), "If no one successor accepts the tuple\
        the messages must remain in their respective input port queues");
}
192 
193 //! Test join_node<reserving> behavior
194 //! \brief \ref requirement
195 TEST_CASE("join_node reserving policy") {
196     conformance::test_with_reserving_join_node_class<oneapi::tbb::flow::write_once_node<int>>();
197 }
198 
//! Custom hash-compare used with key_matching in the tests below.
//! hash() widens the key to std::size_t before the arithmetic, so the computation wraps
//! (well-defined) instead of overflowing a signed KeyType. equal() compares the keys
//! themselves: two distinct keys that happen to share a hash value are still distinct.
template<typename KeyType>
struct MyHash{
    std::size_t hash(const KeyType &k) const {
        return static_cast<std::size_t>(k) * 2000 + 3;
    }

    bool equal(const KeyType &k1, const KeyType &k2) const{
        return k1 == k2;
    }
};
209 
210 //! Test join_node<key_matching> behavior
211 //! \brief \ref requirement
212 TEST_CASE("join_node key_matching policy"){
213     oneapi::tbb::flow::graph g;
214     auto body1 = [](const oneapi::tbb::flow::continue_msg &) -> int { return 1; };
215     auto body2 = [](const float &val) -> int { return static_cast<int>(val); };
216 
217     oneapi::tbb::flow::join_node<std::tuple<oneapi::tbb::flow::continue_msg, float>,
218         oneapi::tbb::flow::key_matching<int, MyHash<int>>> testing_node(g, body1, body2);
219 
220     oneapi::tbb::flow::input_port<0>(testing_node).try_put(oneapi::tbb::flow::continue_msg());
221     oneapi::tbb::flow::input_port<1>(testing_node).try_put(1.3f);
222 
223     g.wait_for_all();
224 
225     std::tuple<oneapi::tbb::flow::continue_msg, float> tmp;
226     CHECK_MESSAGE((testing_node.try_get(tmp)), "Mapped keys should match.\
227         If no successor accepts the tuple, it is must been saved and will be forwarded on a subsequent try_get");
228     CHECK_MESSAGE((!testing_node.try_get(tmp)), "Message should not exist after item is consumed");
229 }
230 
231 //! Test join_node<tag_matching> behavior
232 //! \brief \ref requirement
233 TEST_CASE("join_node tag_matching policy"){
234     oneapi::tbb::flow::graph g;
235     auto body1 = [](const oneapi::tbb::flow::continue_msg &) -> oneapi::tbb::flow::tag_value { return 1; };
236     auto body2 = [](const float &val) -> oneapi::tbb::flow::tag_value { return static_cast<oneapi::tbb::flow::tag_value>(val); };
237 
238     oneapi::tbb::flow::join_node<std::tuple<oneapi::tbb::flow::continue_msg, float>,
239         oneapi::tbb::flow::tag_matching> testing_node(g, body1, body2);
240 
241     oneapi::tbb::flow::input_port<0>(testing_node).try_put(oneapi::tbb::flow::continue_msg());
242     oneapi::tbb::flow::input_port<1>(testing_node).try_put(1.3f);
243 
244     g.wait_for_all();
245 
246     std::tuple<oneapi::tbb::flow::continue_msg, float> tmp;
247     CHECK_MESSAGE((testing_node.try_get(tmp) == true), "Mapped keys should match");
248 }
249 
250 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
251 //! Test deduction guides
252 //! \brief \ref requirement
TEST_CASE("Deduction guides test"){
    // The deduction checks in test_deduction_guides() are static_asserts, so they are
    // enforced at compile time. This call registers them under a test case and runs
    // the node constructions at run time.
    test_deduction_guides();
}
256 #endif
257 
//! Test that join_node::input_ports() returns a reference to the tuple of input ports.
259 //! \brief \ref interface \ref requirement
260 TEST_CASE("join_node output_ports") {
261     oneapi::tbb::flow::graph g;
262     oneapi::tbb::flow::join_node<std::tuple<int>> node(g);
263 
264     CHECK_MESSAGE((std::is_same<oneapi::tbb::flow::join_node<std::tuple<int>>::input_ports_type&,
265         decltype(node.input_ports())>::value), "join_node input_ports should returns a tuple of input ports");
266 }
267