151c0b2f7Stbbdev /* 2*b15aabb3Stbbdev Copyright (c) 2005-2021 Intel Corporation 351c0b2f7Stbbdev 451c0b2f7Stbbdev Licensed under the Apache License, Version 2.0 (the "License"); 551c0b2f7Stbbdev you may not use this file except in compliance with the License. 651c0b2f7Stbbdev You may obtain a copy of the License at 751c0b2f7Stbbdev 851c0b2f7Stbbdev http://www.apache.org/licenses/LICENSE-2.0 951c0b2f7Stbbdev 1051c0b2f7Stbbdev Unless required by applicable law or agreed to in writing, software 1151c0b2f7Stbbdev distributed under the License is distributed on an "AS IS" BASIS, 1251c0b2f7Stbbdev WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1351c0b2f7Stbbdev See the License for the specific language governing permissions and 1451c0b2f7Stbbdev limitations under the License. 1551c0b2f7Stbbdev */ 1651c0b2f7Stbbdev 1751c0b2f7Stbbdev #include "common/config.h" 1851c0b2f7Stbbdev 1951c0b2f7Stbbdev #include "test_join_node.h" 2051c0b2f7Stbbdev 2151c0b2f7Stbbdev 2251c0b2f7Stbbdev //! \file test_join_node.cpp 2351c0b2f7Stbbdev //! \brief Test for [flow_graph.join_node] specification 2451c0b2f7Stbbdev 2551c0b2f7Stbbdev 2651c0b2f7Stbbdev static std::atomic<int> output_count; 2751c0b2f7Stbbdev 2851c0b2f7Stbbdev // get the tag from the output tuple and emit it. 2951c0b2f7Stbbdev // the first tuple component is tag * 2 cast to the type 3051c0b2f7Stbbdev template<typename OutputTupleType> 3151c0b2f7Stbbdev class recirc_output_func_body { 3251c0b2f7Stbbdev public: 3351c0b2f7Stbbdev // we only need this to use input_node_helper 3451c0b2f7Stbbdev typedef typename tbb::flow::join_node<OutputTupleType, tbb::flow::tag_matching> join_node_type; 3551c0b2f7Stbbdev static const int N = std::tuple_size<OutputTupleType>::value; 3651c0b2f7Stbbdev int operator()(const OutputTupleType &v) { 3751c0b2f7Stbbdev int out = int(std::get<0>(v))/2; 3851c0b2f7Stbbdev input_node_helper<N, join_node_type>::only_check_value(out, v); 3951c0b2f7Stbbdev ++output_count; 4051c0b2f7Stbbdev return out; 4151c0b2f7Stbbdev } 4251c0b2f7Stbbdev }; 4351c0b2f7Stbbdev 4451c0b2f7Stbbdev template<typename JType> 4551c0b2f7Stbbdev class tag_recirculation_test { 4651c0b2f7Stbbdev public: 4751c0b2f7Stbbdev typedef typename JType::output_type TType; 4851c0b2f7Stbbdev typedef typename std::tuple<int, tbb::flow::continue_msg> input_tuple_type; 4951c0b2f7Stbbdev typedef tbb::flow::join_node<input_tuple_type, tbb::flow::reserving> input_join_type; 5051c0b2f7Stbbdev static const int N = std::tuple_size<TType>::value; 5151c0b2f7Stbbdev static void test() { 5251c0b2f7Stbbdev input_node_helper<N, JType>::print_remark("Recirculation test of tag-matching join"); 5351c0b2f7Stbbdev INFO(" >\n"); 5451c0b2f7Stbbdev for(int maxTag = 1; maxTag <10; maxTag *= 3) { 5551c0b2f7Stbbdev for(int i = 0; i < N; ++i) all_input_nodes[i][0] = NULL; 5651c0b2f7Stbbdev 5751c0b2f7Stbbdev tbb::flow::graph g; 5851c0b2f7Stbbdev // this is the tag-matching join we're testing 5951c0b2f7Stbbdev JType * my_join = makeJoin<N, JType, tbb::flow::tag_matching>::create(g); 6051c0b2f7Stbbdev // input_node for continue messages 6151c0b2f7Stbbdev tbb::flow::input_node<tbb::flow::continue_msg> snode(g, recirc_input_node_body()); 6251c0b2f7Stbbdev // reserving join that matches recirculating tags with continue messages. 6351c0b2f7Stbbdev input_join_type * my_input_join = makeJoin<2, input_join_type, tbb::flow::reserving>::create(g); 6451c0b2f7Stbbdev // tbb::flow::make_edge(snode, tbb::flow::input_port<1>(*my_input_join)); 6551c0b2f7Stbbdev tbb::flow::make_edge(snode, std::get<1>(my_input_join->input_ports())); 6651c0b2f7Stbbdev // queue to hold the tags 6751c0b2f7Stbbdev tbb::flow::queue_node<int> tag_queue(g); 6851c0b2f7Stbbdev tbb::flow::make_edge(tag_queue, tbb::flow::input_port<0>(*my_input_join)); 6951c0b2f7Stbbdev // add all the function_nodes that are inputs to the tag-matching join 7051c0b2f7Stbbdev input_node_helper<N, JType>::add_recirc_func_nodes(*my_join, *my_input_join, g); 7151c0b2f7Stbbdev // add the function_node that accepts the output of the join and emits the int tag it was based on 7251c0b2f7Stbbdev tbb::flow::function_node<TType, int> recreate_tag(g, tbb::flow::unlimited, recirc_output_func_body<TType>()); 7351c0b2f7Stbbdev tbb::flow::make_edge(*my_join, recreate_tag); 7451c0b2f7Stbbdev // now the recirculating part (output back to the queue) 7551c0b2f7Stbbdev tbb::flow::make_edge(recreate_tag, tag_queue); 7651c0b2f7Stbbdev 7751c0b2f7Stbbdev // put the tags into the queue 7851c0b2f7Stbbdev for(int t = 1; t<=maxTag; ++t) tag_queue.try_put(t); 7951c0b2f7Stbbdev 8051c0b2f7Stbbdev input_count = Recirc_count; 8151c0b2f7Stbbdev output_count = 0; 8251c0b2f7Stbbdev 8351c0b2f7Stbbdev // start up the source node to get things going 8451c0b2f7Stbbdev snode.activate(); 8551c0b2f7Stbbdev 8651c0b2f7Stbbdev // wait for everything to stop 8751c0b2f7Stbbdev g.wait_for_all(); 8851c0b2f7Stbbdev 8951c0b2f7Stbbdev CHECK_MESSAGE( (output_count==Recirc_count), "not all instances were received"); 9051c0b2f7Stbbdev 9151c0b2f7Stbbdev int j{}; 9251c0b2f7Stbbdev // grab the tags from the queue, record them 9351c0b2f7Stbbdev std::vector<bool> out_tally(maxTag, false); 9451c0b2f7Stbbdev for(int i = 0; i < maxTag; ++i) { 9551c0b2f7Stbbdev CHECK_MESSAGE( (tag_queue.try_get(j)), "not enough tags in queue"); 9651c0b2f7Stbbdev CHECK_MESSAGE( (!out_tally.at(j-1)), "duplicate tag from queue"); 9751c0b2f7Stbbdev out_tally[j-1] = true; 9851c0b2f7Stbbdev } 9951c0b2f7Stbbdev CHECK_MESSAGE( (!tag_queue.try_get(j)), "Extra tags in recirculation queue"); 10051c0b2f7Stbbdev 10151c0b2f7Stbbdev // deconstruct graph 10251c0b2f7Stbbdev input_node_helper<N, JType>::remove_recirc_func_nodes(*my_join, *my_input_join); 10351c0b2f7Stbbdev tbb::flow::remove_edge(*my_join, recreate_tag); 10451c0b2f7Stbbdev makeJoin<N, JType, tbb::flow::tag_matching>::destroy(my_join); 10551c0b2f7Stbbdev tbb::flow::remove_edge(tag_queue, tbb::flow::input_port<0>(*my_input_join)); 10651c0b2f7Stbbdev tbb::flow::remove_edge(snode, tbb::flow::input_port<1>(*my_input_join)); 10751c0b2f7Stbbdev makeJoin<2, input_join_type, tbb::flow::reserving>::destroy(my_input_join); 10851c0b2f7Stbbdev } 10951c0b2f7Stbbdev } 11051c0b2f7Stbbdev }; 11151c0b2f7Stbbdev 11251c0b2f7Stbbdev template<typename JType> 11351c0b2f7Stbbdev class generate_recirc_test { 11451c0b2f7Stbbdev public: 11551c0b2f7Stbbdev typedef tbb::flow::join_node<JType, tbb::flow::tag_matching> join_node_type; 11651c0b2f7Stbbdev static void do_test() { 11751c0b2f7Stbbdev tag_recirculation_test<join_node_type>::test(); 11851c0b2f7Stbbdev } 11951c0b2f7Stbbdev }; 12051c0b2f7Stbbdev 12151c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 12251c0b2f7Stbbdev #include <array> 12351c0b2f7Stbbdev #include <vector> 12451c0b2f7Stbbdev void test_follows_and_precedes_api() { 12551c0b2f7Stbbdev using msg_t = tbb::flow::continue_msg; 12651c0b2f7Stbbdev using JoinOutputType = std::tuple<msg_t, msg_t, msg_t>; 12751c0b2f7Stbbdev 12851c0b2f7Stbbdev std::array<msg_t, 3> messages_for_follows = { {msg_t(), msg_t(), msg_t()} }; 12951c0b2f7Stbbdev std::vector<msg_t> messages_for_precedes = {msg_t(), msg_t(), msg_t()}; 13051c0b2f7Stbbdev 13151c0b2f7Stbbdev follows_and_precedes_testing::test_follows 13251c0b2f7Stbbdev <msg_t, tbb::flow::join_node<JoinOutputType>, tbb::flow::buffer_node<msg_t>>(messages_for_follows); 13351c0b2f7Stbbdev follows_and_precedes_testing::test_follows 13451c0b2f7Stbbdev <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::queueing>>(messages_for_follows); 13551c0b2f7Stbbdev follows_and_precedes_testing::test_follows 13651c0b2f7Stbbdev <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::reserving>, tbb::flow::buffer_node<msg_t>>(messages_for_follows); 13751c0b2f7Stbbdev auto b = [](msg_t) { return msg_t(); }; 13851c0b2f7Stbbdev class hash_compare { 13951c0b2f7Stbbdev public: 14051c0b2f7Stbbdev std::size_t hash(msg_t) const { return 0; } 14151c0b2f7Stbbdev bool equal(msg_t, msg_t) const { return true; } 14251c0b2f7Stbbdev }; 14351c0b2f7Stbbdev follows_and_precedes_testing::test_follows 14451c0b2f7Stbbdev <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::key_matching<msg_t, hash_compare>>, tbb::flow::buffer_node<msg_t>> 14551c0b2f7Stbbdev (messages_for_follows, b, b, b); 14651c0b2f7Stbbdev 14751c0b2f7Stbbdev follows_and_precedes_testing::test_precedes 14851c0b2f7Stbbdev <msg_t, tbb::flow::join_node<JoinOutputType>>(messages_for_precedes); 14951c0b2f7Stbbdev follows_and_precedes_testing::test_precedes 15051c0b2f7Stbbdev <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::queueing>>(messages_for_precedes); 15151c0b2f7Stbbdev follows_and_precedes_testing::test_precedes 15251c0b2f7Stbbdev <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::reserving>>(messages_for_precedes); 15351c0b2f7Stbbdev follows_and_precedes_testing::test_precedes 15451c0b2f7Stbbdev <msg_t, tbb::flow::join_node<JoinOutputType, tbb::flow::key_matching<msg_t, hash_compare>>> 15551c0b2f7Stbbdev (messages_for_precedes, b, b, b); 15651c0b2f7Stbbdev } 15751c0b2f7Stbbdev #endif 15851c0b2f7Stbbdev 15951c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 16051c0b2f7Stbbdev void test_deduction_guides() { 16151c0b2f7Stbbdev using namespace tbb::flow; 16251c0b2f7Stbbdev 16351c0b2f7Stbbdev graph g; 16451c0b2f7Stbbdev using tuple_type = std::tuple<int, int, int>; 16551c0b2f7Stbbdev broadcast_node<int> b1(g), b2(g), b3(g); 16651c0b2f7Stbbdev broadcast_node<tuple_type> b4(g); 16751c0b2f7Stbbdev join_node<tuple_type> j0(g); 16851c0b2f7Stbbdev 16951c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 17051c0b2f7Stbbdev join_node j1(follows(b1, b2, b3)); 17151c0b2f7Stbbdev static_assert(std::is_same_v<decltype(j1), join_node<tuple_type>>); 17251c0b2f7Stbbdev 17351c0b2f7Stbbdev join_node j2(follows(b1, b2, b3), reserving()); 17451c0b2f7Stbbdev static_assert(std::is_same_v<decltype(j2), join_node<tuple_type, reserving>>); 17551c0b2f7Stbbdev 17651c0b2f7Stbbdev join_node j3(precedes(b4)); 17751c0b2f7Stbbdev static_assert(std::is_same_v<decltype(j3), join_node<tuple_type>>); 17851c0b2f7Stbbdev 17951c0b2f7Stbbdev join_node j4(precedes(b4), reserving()); 18051c0b2f7Stbbdev static_assert(std::is_same_v<decltype(j4), join_node<tuple_type, reserving>>); 18151c0b2f7Stbbdev #endif 18251c0b2f7Stbbdev 18351c0b2f7Stbbdev join_node j5(j0); 18451c0b2f7Stbbdev static_assert(std::is_same_v<decltype(j5), join_node<tuple_type>>); 18551c0b2f7Stbbdev } 18651c0b2f7Stbbdev 18751c0b2f7Stbbdev #endif 18851c0b2f7Stbbdev 18951c0b2f7Stbbdev namespace multiple_predecessors { 19051c0b2f7Stbbdev 19151c0b2f7Stbbdev using namespace tbb::flow; 19251c0b2f7Stbbdev 19351c0b2f7Stbbdev using join_node_t = join_node<std::tuple<continue_msg, continue_msg, continue_msg>, reserving>; 19451c0b2f7Stbbdev using queue_node_t = queue_node<std::tuple<continue_msg, continue_msg, continue_msg>>; 19551c0b2f7Stbbdev 19651c0b2f7Stbbdev void twist_join_connections( 19751c0b2f7Stbbdev buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2, buffer_node<continue_msg>& bn3, 19851c0b2f7Stbbdev join_node_t& jn) 19951c0b2f7Stbbdev { 20051c0b2f7Stbbdev // order, in which edges are created/destroyed, is important 20151c0b2f7Stbbdev make_edge(bn1, input_port<0>(jn)); 20251c0b2f7Stbbdev make_edge(bn2, input_port<0>(jn)); 20351c0b2f7Stbbdev make_edge(bn3, input_port<0>(jn)); 20451c0b2f7Stbbdev 20551c0b2f7Stbbdev remove_edge(bn3, input_port<0>(jn)); 20651c0b2f7Stbbdev make_edge (bn3, input_port<2>(jn)); 20751c0b2f7Stbbdev 20851c0b2f7Stbbdev remove_edge(bn2, input_port<0>(jn)); 20951c0b2f7Stbbdev make_edge (bn2, input_port<1>(jn)); 21051c0b2f7Stbbdev } 21151c0b2f7Stbbdev 21251c0b2f7Stbbdev std::unique_ptr<join_node_t> connect_join_via_make_edge( 21351c0b2f7Stbbdev graph& g, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2, 21451c0b2f7Stbbdev buffer_node<continue_msg>& bn3, queue_node_t& qn) 21551c0b2f7Stbbdev { 21651c0b2f7Stbbdev std::unique_ptr<join_node_t> jn( new join_node_t(g) ); 21751c0b2f7Stbbdev twist_join_connections( bn1, bn2, bn3, *jn ); 21851c0b2f7Stbbdev make_edge(*jn, qn); 21951c0b2f7Stbbdev return jn; 22051c0b2f7Stbbdev } 22151c0b2f7Stbbdev 22251c0b2f7Stbbdev #if TBB_PREVIEW_FLOW_GRAPH_FEATURES 22351c0b2f7Stbbdev std::unique_ptr<join_node_t> connect_join_via_follows( 22451c0b2f7Stbbdev graph&, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2, 22551c0b2f7Stbbdev buffer_node<continue_msg>& bn3, queue_node_t& qn) 22651c0b2f7Stbbdev { 22751c0b2f7Stbbdev auto bn_set = make_node_set(bn1, bn2, bn3); 22851c0b2f7Stbbdev std::unique_ptr<join_node_t> jn( new join_node_t(follows(bn_set)) ); 22951c0b2f7Stbbdev make_edge(*jn, qn); 23051c0b2f7Stbbdev return jn; 23151c0b2f7Stbbdev } 23251c0b2f7Stbbdev 23351c0b2f7Stbbdev std::unique_ptr<join_node_t> connect_join_via_precedes( 23451c0b2f7Stbbdev graph&, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2, 23551c0b2f7Stbbdev buffer_node<continue_msg>& bn3, queue_node_t& qn) 23651c0b2f7Stbbdev { 23751c0b2f7Stbbdev auto qn_set = make_node_set(qn); 23851c0b2f7Stbbdev auto qn_copy_set = qn_set; 23951c0b2f7Stbbdev std::unique_ptr<join_node_t> jn( new join_node_t(precedes(qn_copy_set)) ); 24051c0b2f7Stbbdev twist_join_connections( bn1, bn2, bn3, *jn ); 24151c0b2f7Stbbdev return jn; 24251c0b2f7Stbbdev } 24351c0b2f7Stbbdev #endif // TBB_PREVIEW_FLOW_GRAPH_FEATURES 24451c0b2f7Stbbdev 24551c0b2f7Stbbdev void run_and_check( 24651c0b2f7Stbbdev graph& g, buffer_node<continue_msg>& bn1, buffer_node<continue_msg>& bn2, 24751c0b2f7Stbbdev buffer_node<continue_msg>& bn3, queue_node_t& qn, bool expected) 24851c0b2f7Stbbdev { 24951c0b2f7Stbbdev std::tuple<continue_msg, continue_msg, continue_msg> msg; 25051c0b2f7Stbbdev 25151c0b2f7Stbbdev bn1.try_put(continue_msg()); 25251c0b2f7Stbbdev bn2.try_put(continue_msg()); 25351c0b2f7Stbbdev bn3.try_put(continue_msg()); 25451c0b2f7Stbbdev g.wait_for_all(); 25551c0b2f7Stbbdev 25651c0b2f7Stbbdev CHECK_MESSAGE( 25751c0b2f7Stbbdev (qn.try_get(msg) == expected), 25851c0b2f7Stbbdev "Unexpected message absence/existence at the end of the graph." 25951c0b2f7Stbbdev ); 26051c0b2f7Stbbdev } 26151c0b2f7Stbbdev 26251c0b2f7Stbbdev template<typename ConnectJoinNodeFunc> 26351c0b2f7Stbbdev void test(ConnectJoinNodeFunc&& connect_join_node) { 26451c0b2f7Stbbdev graph g; 26551c0b2f7Stbbdev buffer_node<continue_msg> bn1(g); 26651c0b2f7Stbbdev buffer_node<continue_msg> bn2(g); 26751c0b2f7Stbbdev buffer_node<continue_msg> bn3(g); 26851c0b2f7Stbbdev queue_node_t qn(g); 26951c0b2f7Stbbdev 27051c0b2f7Stbbdev auto jn = connect_join_node(g, bn1, bn2, bn3, qn); 27151c0b2f7Stbbdev 27251c0b2f7Stbbdev run_and_check(g, bn1, bn2, bn3, qn, /*expected=*/true); 27351c0b2f7Stbbdev 27451c0b2f7Stbbdev remove_edge(bn3, input_port<2>(*jn)); 27551c0b2f7Stbbdev remove_edge(bn2, input_port<1>(*jn)); 27651c0b2f7Stbbdev remove_edge(bn1, input_port<0>(*jn)); 27751c0b2f7Stbbdev remove_edge(*jn, qn); 27851c0b2f7Stbbdev 27951c0b2f7Stbbdev run_and_check(g, bn1, bn2, bn3, qn, /*expected=*/false); 28051c0b2f7Stbbdev } 28151c0b2f7Stbbdev } // namespace multiple_predecessors 28251c0b2f7Stbbdev 28351c0b2f7Stbbdev 28451c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 28551c0b2f7Stbbdev //! Test follows and precedes API 28651c0b2f7Stbbdev //! \brief \ref error_guessing 28751c0b2f7Stbbdev TEST_CASE("Test follows and preceedes API"){ 28851c0b2f7Stbbdev test_follows_and_precedes_api(); 28951c0b2f7Stbbdev } 29051c0b2f7Stbbdev #endif 29151c0b2f7Stbbdev 29251c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 29351c0b2f7Stbbdev //! Test deduction guides 29451c0b2f7Stbbdev //! \brief \ref requirement 29551c0b2f7Stbbdev TEST_CASE("Deduction guides test"){ 29651c0b2f7Stbbdev test_deduction_guides(); 29751c0b2f7Stbbdev } 29851c0b2f7Stbbdev #endif 29951c0b2f7Stbbdev 30051c0b2f7Stbbdev //! Test hash buffers behavior 30151c0b2f7Stbbdev //! \brief \ref error_guessing 30251c0b2f7Stbbdev TEST_CASE("Tagged buffers test"){ 30351c0b2f7Stbbdev TestTaggedBuffers(); 30451c0b2f7Stbbdev } 30551c0b2f7Stbbdev 30651c0b2f7Stbbdev //! Test with various policies and tuple sizes 30751c0b2f7Stbbdev //! \brief \ref error_guessing 30851c0b2f7Stbbdev TEST_CASE("Main test"){ 30951c0b2f7Stbbdev test_main<tbb::flow::queueing>(); 31051c0b2f7Stbbdev test_main<tbb::flow::reserving>(); 31151c0b2f7Stbbdev test_main<tbb::flow::tag_matching>(); 31251c0b2f7Stbbdev } 31351c0b2f7Stbbdev 31451c0b2f7Stbbdev //! Test with recirculating tags 31551c0b2f7Stbbdev //! \brief \ref error_guessing 31651c0b2f7Stbbdev TEST_CASE("Recirculation test"){ 31751c0b2f7Stbbdev generate_recirc_test<std::tuple<int,float> >::do_test(); 31851c0b2f7Stbbdev } 31951c0b2f7Stbbdev 32051c0b2f7Stbbdev //! Test maintaining correct count of ports without input 32151c0b2f7Stbbdev //! \brief \ref error_guessing 32251c0b2f7Stbbdev TEST_CASE("Test removal of the predecessor while having none") { 32351c0b2f7Stbbdev using namespace multiple_predecessors; 32451c0b2f7Stbbdev 32551c0b2f7Stbbdev test(connect_join_via_make_edge); 32651c0b2f7Stbbdev #if TBB_PREVIEW_FLOW_GRAPH_FEATURES 32751c0b2f7Stbbdev test(connect_join_via_follows); 32851c0b2f7Stbbdev test(connect_join_via_precedes); 32951c0b2f7Stbbdev #endif 33051c0b2f7Stbbdev } 331