1 /* 2 Copyright (c) 2020-2023 Intel Corporation 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 #if __INTEL_COMPILER && _MSC_VER 18 #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated 19 #endif 20 21 #include "conformance_flowgraph.h" 22 #include "common/test_invoke.h" 23 24 //! \file conformance_join_node.cpp 25 //! \brief Test for [flow_graph.join_node] specification 26 27 using input_msg = conformance::message</*default_ctor*/true, /*copy_ctor*/true, /*copy_assign*/true>; 28 using my_input_tuple = std::tuple<int, float, input_msg>; 29 30 std::vector<my_input_tuple> get_values( conformance::test_push_receiver<my_input_tuple>& rr ) { 31 std::vector<my_input_tuple> messages; 32 my_input_tuple tmp(0, 0.f, input_msg(0)); 33 while(rr.try_get(tmp)) { 34 messages.push_back(tmp); 35 } 36 return messages; 37 } 38 39 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 40 void test_deduction_guides() { 41 using namespace tbb::flow; 42 43 graph g; 44 using tuple_type = std::tuple<int, int, int>; 45 broadcast_node<int> b1(g), b2(g), b3(g); 46 broadcast_node<tuple_type> b4(g); 47 join_node<tuple_type> j0(g); 48 49 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 50 join_node j1(follows(b1, b2, b3)); 51 static_assert(std::is_same_v<decltype(j1), join_node<tuple_type>>); 52 53 join_node j2(follows(b1, b2, b3), reserving()); 54 static_assert(std::is_same_v<decltype(j2), join_node<tuple_type, reserving>>); 55 56 join_node j3(precedes(b4)); 57 static_assert(std::is_same_v<decltype(j3), join_node<tuple_type>>); 58 59 join_node j4(precedes(b4), reserving()); 60 static_assert(std::is_same_v<decltype(j4), join_node<tuple_type, reserving>>); 61 #endif 62 63 join_node j5(j0); 64 static_assert(std::is_same_v<decltype(j5), join_node<tuple_type>>); 65 } 66 67 #endif 68 69 //! The node that is constructed has a reference to the same graph object as src. 70 //! The list of predecessors, messages in the input ports, and successors are not copied. 71 //! \brief \ref interface 72 TEST_CASE("join_node copy constructor"){ 73 oneapi::tbb::flow::graph g; 74 oneapi::tbb::flow::continue_node<int> node0( g, 75 [](oneapi::tbb::flow::continue_msg) { return 1; } ); 76 77 oneapi::tbb::flow::join_node<std::tuple<int>> node1(g); 78 conformance::test_push_receiver<std::tuple<int>> node2(g); 79 conformance::test_push_receiver<std::tuple<int>> node3(g); 80 81 oneapi::tbb::flow::make_edge(node0, oneapi::tbb::flow::input_port<0>(node1)); 82 oneapi::tbb::flow::make_edge(node1, node2); 83 oneapi::tbb::flow::join_node<std::tuple<int>> node_copy(node1); 84 85 oneapi::tbb::flow::make_edge(node_copy, node3); 86 87 oneapi::tbb::flow::input_port<0>(node_copy).try_put(1); 88 g.wait_for_all(); 89 90 auto values = conformance::get_values(node3); 91 CHECK_MESSAGE((conformance::get_values(node2).size() == 0 && values.size() == 1), "Copied node doesn`t copy successor"); 92 93 node0.try_put(oneapi::tbb::flow::continue_msg()); 94 g.wait_for_all(); 95 96 CHECK_MESSAGE((conformance::get_values(node2).size() == 1 && conformance::get_values(node3).size() == 0), "Copied node doesn`t copy predecessor"); 97 98 oneapi::tbb::flow::remove_edge(node1, node2); 99 oneapi::tbb::flow::input_port<0>(node1).try_put(1); 100 g.wait_for_all(); 101 oneapi::tbb::flow::join_node<std::tuple<int>> node_copy2(node1); 102 oneapi::tbb::flow::make_edge(node_copy2, node3); 103 oneapi::tbb::flow::input_port<0>(node_copy2).try_put(2); 104 g.wait_for_all(); 105 CHECK_MESSAGE((std::get<0>(conformance::get_values(node3)[0]) == 2), "Copied node doesn`t copy messages in the input ports"); 106 } 107 108 //! Test inheritance relations 109 //! \brief \ref interface 110 TEST_CASE("join_node inheritance"){ 111 CHECK_MESSAGE((std::is_base_of<oneapi::tbb::flow::graph_node, 112 oneapi::tbb::flow::join_node<my_input_tuple>>::value), 113 "join_node should be derived from graph_node"); 114 CHECK_MESSAGE((std::is_base_of<oneapi::tbb::flow::sender<my_input_tuple>, 115 oneapi::tbb::flow::join_node<my_input_tuple>>::value), 116 "join_node should be derived from sender<input_tuple>"); 117 } 118 119 //! Test join_node<queueing> behavior and broadcast property 120 //! \brief \ref requirement 121 TEST_CASE("join_node queueing policy and broadcast property") { 122 oneapi::tbb::flow::graph g; 123 oneapi::tbb::flow::function_node<int, int> 124 f1( g, oneapi::tbb::flow::unlimited, [](const int &i) { return i; } ); 125 oneapi::tbb::flow::function_node<float, float> 126 f2( g, oneapi::tbb::flow::unlimited, [](const float &f) { return f; } ); 127 oneapi::tbb::flow::continue_node<input_msg> c1( g, 128 [](oneapi::tbb::flow::continue_msg) { return input_msg(1); } ); 129 130 oneapi::tbb::flow::join_node<my_input_tuple, oneapi::tbb::flow::queueing> testing_node(g); 131 132 conformance::test_push_receiver<my_input_tuple> q_node(g); 133 134 std::atomic<int> number{1}; 135 oneapi::tbb::flow::function_node<my_input_tuple, my_input_tuple> 136 f3( g, oneapi::tbb::flow::unlimited, 137 [&]( const my_input_tuple &t ) { 138 CHECK_MESSAGE((std::get<0>(t) == number), "Messages must be in first-in first-out order" ); 139 CHECK_MESSAGE((std::get<1>(t) == static_cast<float>(number) + 0.5f), "Messages must be in first-in first-out order" ); 140 CHECK_MESSAGE((std::get<2>(t) == 1), "Messages must be in first-in first-out order" ); 141 ++number; 142 return t; 143 } ); 144 145 oneapi::tbb::flow::make_edge(f1, oneapi::tbb::flow::input_port<0>(testing_node)); 146 oneapi::tbb::flow::make_edge(f2, oneapi::tbb::flow::input_port<1>(testing_node)); 147 oneapi::tbb::flow::make_edge(c1, oneapi::tbb::flow::input_port<2>(testing_node)); 148 make_edge(testing_node, f3); 149 make_edge(f3, q_node); 150 151 f1.try_put(1); 152 g.wait_for_all(); 153 CHECK_MESSAGE((get_values(q_node).size() == 0), 154 "join_node must broadcast when there is at least one message at each input port"); 155 f1.try_put(2); 156 f2.try_put(1.5f); 157 g.wait_for_all(); 158 CHECK_MESSAGE((get_values(q_node).size() == 0), 159 "join_node must broadcast when there is at least one message at each input port"); 160 f1.try_put(3); 161 f2.try_put(2.5f); 162 c1.try_put(oneapi::tbb::flow::continue_msg()); 163 g.wait_for_all(); 164 CHECK_MESSAGE((get_values(q_node).size() == 1), 165 "join_node must broadcast when there is at least one message at each input port"); 166 f2.try_put(3.5f); 167 c1.try_put(oneapi::tbb::flow::continue_msg()); 168 g.wait_for_all(); 169 CHECK_MESSAGE((get_values(q_node).size() == 1), 170 "If at least one successor accepts the tuple, the head of each input port’s queue is removed"); 171 c1.try_put(oneapi::tbb::flow::continue_msg()); 172 g.wait_for_all(); 173 CHECK_MESSAGE((get_values(q_node).size() == 1), 174 "If at least one successor accepts the tuple, the head of each input port’s queue is removed"); 175 c1.try_put(oneapi::tbb::flow::continue_msg()); 176 g.wait_for_all(); 177 CHECK_MESSAGE((get_values(q_node).size() == 0), 178 "join_node must broadcast when there is at least one message at each input port"); 179 180 oneapi::tbb::flow::remove_edge(testing_node, f3); 181 182 f1.try_put(1); 183 f2.try_put(1); 184 c1.try_put(oneapi::tbb::flow::continue_msg()); 185 g.wait_for_all(); 186 187 my_input_tuple tmp(0, 0.f, input_msg(0)); 188 CHECK_MESSAGE((testing_node.try_get(tmp)), "If no one successor accepts the tuple the messages\ 189 must remain in their respective input port queues"); 190 CHECK_MESSAGE((tmp == my_input_tuple(1, 1.f, input_msg(1))), "If no one successor accepts the tuple\ 191 the messages must remain in their respective input port queues"); 192 } 193 194 //! Test join_node<reserving> behavior 195 //! \brief \ref requirement 196 TEST_CASE("join_node reserving policy") { 197 conformance::test_with_reserving_join_node_class<oneapi::tbb::flow::write_once_node<int>>(); 198 } 199 200 template<typename KeyType> 201 struct MyHash{ 202 std::size_t hash(const KeyType &k) const { 203 return k * 2000 + 3; 204 } 205 206 bool equal(const KeyType &k1, const KeyType &k2) const{ 207 return hash(k1) == hash(k2); 208 } 209 }; 210 211 //! Test join_node<key_matching> behavior 212 //! \brief \ref requirement 213 TEST_CASE("join_node key_matching policy"){ 214 oneapi::tbb::flow::graph g; 215 auto body1 = [](const oneapi::tbb::flow::continue_msg &) -> int { return 1; }; 216 auto body2 = [](const float &val) -> int { return static_cast<int>(val); }; 217 218 oneapi::tbb::flow::join_node<std::tuple<oneapi::tbb::flow::continue_msg, float>, 219 oneapi::tbb::flow::key_matching<int, MyHash<int>>> testing_node(g, body1, body2); 220 221 oneapi::tbb::flow::input_port<0>(testing_node).try_put(oneapi::tbb::flow::continue_msg()); 222 oneapi::tbb::flow::input_port<1>(testing_node).try_put(1.3f); 223 224 g.wait_for_all(); 225 226 std::tuple<oneapi::tbb::flow::continue_msg, float> tmp; 227 CHECK_MESSAGE((testing_node.try_get(tmp)), "Mapped keys should match.\ 228 If no successor accepts the tuple, it is must been saved and will be forwarded on a subsequent try_get"); 229 CHECK_MESSAGE((!testing_node.try_get(tmp)), "Message should not exist after item is consumed"); 230 } 231 232 //! Test join_node<tag_matching> behavior 233 //! \brief \ref requirement 234 TEST_CASE("join_node tag_matching policy"){ 235 oneapi::tbb::flow::graph g; 236 auto body1 = [](const oneapi::tbb::flow::continue_msg &) -> oneapi::tbb::flow::tag_value { return 1; }; 237 auto body2 = [](const float &val) -> oneapi::tbb::flow::tag_value { return static_cast<oneapi::tbb::flow::tag_value>(val); }; 238 239 oneapi::tbb::flow::join_node<std::tuple<oneapi::tbb::flow::continue_msg, float>, 240 oneapi::tbb::flow::tag_matching> testing_node(g, body1, body2); 241 242 oneapi::tbb::flow::input_port<0>(testing_node).try_put(oneapi::tbb::flow::continue_msg()); 243 oneapi::tbb::flow::input_port<1>(testing_node).try_put(1.3f); 244 245 g.wait_for_all(); 246 247 std::tuple<oneapi::tbb::flow::continue_msg, float> tmp; 248 CHECK_MESSAGE((testing_node.try_get(tmp) == true), "Mapped keys should match"); 249 } 250 251 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 252 //! Test deduction guides 253 //! \brief \ref requirement 254 TEST_CASE("Deduction guides test"){ 255 test_deduction_guides(); 256 } 257 #endif 258 259 //! Test join_node input_ports() returns a tuple of input ports. 260 //! \brief \ref interface \ref requirement 261 TEST_CASE("join_node output_ports") { 262 oneapi::tbb::flow::graph g; 263 oneapi::tbb::flow::join_node<std::tuple<int>> node(g); 264 265 CHECK_MESSAGE((std::is_same<oneapi::tbb::flow::join_node<std::tuple<int>>::input_ports_type&, 266 decltype(node.input_ports())>::value), "join_node input_ports should returns a tuple of input ports"); 267 } 268 269 #if __TBB_CPP17_INVOKE_PRESENT 270 271 template <typename K, typename Body1, typename Body2> 272 void test_invoke_basic(Body1 body1, Body2 body2) { 273 static_assert(std::is_same_v<std::decay_t<K>, std::size_t>, "incorrect test setup"); 274 using namespace oneapi::tbb::flow; 275 auto generator = [](std::size_t n) { return test_invoke::SmartID<std::size_t>(n); }; 276 graph g; 277 278 function_node<std::size_t, test_invoke::SmartID<std::size_t>> f1(g, unlimited, generator); 279 function_node<std::size_t, test_invoke::SmartID<std::size_t>> f2(g, unlimited, generator); 280 281 using tuple_type = std::tuple<test_invoke::SmartID<std::size_t>, test_invoke::SmartID<std::size_t>>; 282 using join_type = join_node<tuple_type, key_matching<K>>; 283 284 285 join_type j(g, body1, body2); 286 287 buffer_node<tuple_type> buf(g); 288 289 make_edge(f1, input_port<0>(j)); 290 make_edge(f2, input_port<1>(j)); 291 make_edge(j, buf); 292 293 std::size_t objects_count = 100; 294 for (std::size_t i = 0; i < objects_count; ++i) { 295 f1.try_put(i); 296 f2.try_put(objects_count - i - 1); 297 } 298 299 g.wait_for_all(); 300 301 std::size_t buf_size = 0; 302 tuple_type tpl; 303 304 while(buf.try_get(tpl)) { 305 ++buf_size; 306 CHECK(std::get<0>(tpl).id == std::get<1>(tpl).id); 307 } 308 CHECK(buf_size == objects_count); 309 } 310 311 //! Test that key_matching join_node uses std::invoke to run the body 312 //! \brief \ref requirement 313 TEST_CASE("key_matching join_node invoke semantics") { 314 test_invoke_basic</*K = */std::size_t>(&test_invoke::SmartID<std::size_t>::get_id, &test_invoke::SmartID<std::size_t>::id); 315 test_invoke_basic</*K = */const std::size_t&>(&test_invoke::SmartID<std::size_t>::get_id_ref, &test_invoke::SmartID<std::size_t>::get_id_ref); 316 } 317 #endif // __TBB_CPP17_INVOKE_PRESENT 318