xref: /oneTBB/test/tbb/test_join_node.cpp (revision 57f524ca)
151c0b2f7Stbbdev /*
2e0cc5187SIlya Isaev     Copyright (c) 2005-2022 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
1764715f43SIlya Mishin #ifdef TBB_TEST_LOW_WORKLOAD
1864715f43SIlya Mishin     #undef MAX_TUPLE_TEST_SIZE
1964715f43SIlya Mishin     #define MAX_TUPLE_TEST_SIZE 3
2064715f43SIlya Mishin #endif
2164715f43SIlya Mishin 
2251c0b2f7Stbbdev #include "common/config.h"
2351c0b2f7Stbbdev 
2451c0b2f7Stbbdev #include "test_join_node.h"
25e0cc5187SIlya Isaev #include "common/test_join_node_multiple_predecessors.h"
2651c0b2f7Stbbdev 
2751c0b2f7Stbbdev //! \file test_join_node.cpp
2851c0b2f7Stbbdev //! \brief Test for [flow_graph.join_node] specification
2951c0b2f7Stbbdev 
// Number of tuples that reached recirc_output_func_body; reset to 0 by each
// recirculation round and compared against Recirc_count afterwards.
static std::atomic<int> output_count{0};
3151c0b2f7Stbbdev 
3251c0b2f7Stbbdev // get the tag from the output tuple and emit it.
3351c0b2f7Stbbdev // the first tuple component is tag * 2 cast to the type
3451c0b2f7Stbbdev template<typename OutputTupleType>
3551c0b2f7Stbbdev class recirc_output_func_body {
3651c0b2f7Stbbdev public:
3751c0b2f7Stbbdev     // we only need this to use input_node_helper
3851c0b2f7Stbbdev     typedef typename tbb::flow::join_node<OutputTupleType, tbb::flow::tag_matching> join_node_type;
3951c0b2f7Stbbdev     static const int N = std::tuple_size<OutputTupleType>::value;
operator ()(const OutputTupleType & v)4051c0b2f7Stbbdev     int operator()(const OutputTupleType &v) {
4151c0b2f7Stbbdev         int out = int(std::get<0>(v))/2;
4251c0b2f7Stbbdev         input_node_helper<N, join_node_type>::only_check_value(out, v);
4351c0b2f7Stbbdev         ++output_count;
4451c0b2f7Stbbdev         return out;
4551c0b2f7Stbbdev     }
4651c0b2f7Stbbdev };
4751c0b2f7Stbbdev 
4851c0b2f7Stbbdev template<typename JType>
4951c0b2f7Stbbdev class tag_recirculation_test {
5051c0b2f7Stbbdev public:
5151c0b2f7Stbbdev     typedef typename JType::output_type TType;
5251c0b2f7Stbbdev     typedef typename std::tuple<int, tbb::flow::continue_msg> input_tuple_type;
5351c0b2f7Stbbdev     typedef tbb::flow::join_node<input_tuple_type, tbb::flow::reserving> input_join_type;
5451c0b2f7Stbbdev     static const int N = std::tuple_size<TType>::value;
test()5551c0b2f7Stbbdev     static void test() {
5651c0b2f7Stbbdev         input_node_helper<N, JType>::print_remark("Recirculation test of tag-matching join");
5751c0b2f7Stbbdev         INFO(" >\n");
5851c0b2f7Stbbdev         for(int maxTag = 1; maxTag <10; maxTag *= 3) {
59*57f524caSIlya Isaev             for(int i = 0; i < N; ++i) all_input_nodes[i][0] = nullptr;
6051c0b2f7Stbbdev 
6151c0b2f7Stbbdev             tbb::flow::graph g;
6251c0b2f7Stbbdev             // this is the tag-matching join we're testing
6351c0b2f7Stbbdev             JType * my_join = makeJoin<N, JType, tbb::flow::tag_matching>::create(g);
6451c0b2f7Stbbdev             // input_node for continue messages
6551c0b2f7Stbbdev             tbb::flow::input_node<tbb::flow::continue_msg> snode(g, recirc_input_node_body());
6651c0b2f7Stbbdev             // reserving join that matches recirculating tags with continue messages.
6751c0b2f7Stbbdev             input_join_type * my_input_join = makeJoin<2, input_join_type, tbb::flow::reserving>::create(g);
6851c0b2f7Stbbdev             // tbb::flow::make_edge(snode, tbb::flow::input_port<1>(*my_input_join));
6951c0b2f7Stbbdev             tbb::flow::make_edge(snode, std::get<1>(my_input_join->input_ports()));
7051c0b2f7Stbbdev             // queue to hold the tags
7151c0b2f7Stbbdev             tbb::flow::queue_node<int> tag_queue(g);
7251c0b2f7Stbbdev             tbb::flow::make_edge(tag_queue, tbb::flow::input_port<0>(*my_input_join));
7351c0b2f7Stbbdev             // add all the function_nodes that are inputs to the tag-matching join
7451c0b2f7Stbbdev             input_node_helper<N, JType>::add_recirc_func_nodes(*my_join, *my_input_join, g);
7551c0b2f7Stbbdev             // add the function_node that accepts the output of the join and emits the int tag it was based on
7651c0b2f7Stbbdev             tbb::flow::function_node<TType, int> recreate_tag(g, tbb::flow::unlimited, recirc_output_func_body<TType>());
7751c0b2f7Stbbdev             tbb::flow::make_edge(*my_join, recreate_tag);
7851c0b2f7Stbbdev             // now the recirculating part (output back to the queue)
7951c0b2f7Stbbdev             tbb::flow::make_edge(recreate_tag, tag_queue);
8051c0b2f7Stbbdev 
8151c0b2f7Stbbdev             // put the tags into the queue
8251c0b2f7Stbbdev             for(int t = 1; t<=maxTag; ++t) tag_queue.try_put(t);
8351c0b2f7Stbbdev 
8451c0b2f7Stbbdev             input_count = Recirc_count;
8551c0b2f7Stbbdev             output_count = 0;
8651c0b2f7Stbbdev 
8751c0b2f7Stbbdev             // start up the source node to get things going
8851c0b2f7Stbbdev             snode.activate();
8951c0b2f7Stbbdev 
9051c0b2f7Stbbdev             // wait for everything to stop
9151c0b2f7Stbbdev             g.wait_for_all();
9251c0b2f7Stbbdev 
9351c0b2f7Stbbdev             CHECK_MESSAGE( (output_count==Recirc_count), "not all instances were received");
9451c0b2f7Stbbdev 
9551c0b2f7Stbbdev             int j{};
9651c0b2f7Stbbdev             // grab the tags from the queue, record them
9751c0b2f7Stbbdev             std::vector<bool> out_tally(maxTag, false);
9851c0b2f7Stbbdev             for(int i = 0; i < maxTag; ++i) {
9951c0b2f7Stbbdev                 CHECK_MESSAGE( (tag_queue.try_get(j)), "not enough tags in queue");
10051c0b2f7Stbbdev                 CHECK_MESSAGE( (!out_tally.at(j-1)), "duplicate tag from queue");
10151c0b2f7Stbbdev                 out_tally[j-1] = true;
10251c0b2f7Stbbdev             }
10351c0b2f7Stbbdev             CHECK_MESSAGE( (!tag_queue.try_get(j)), "Extra tags in recirculation queue");
10451c0b2f7Stbbdev 
10551c0b2f7Stbbdev             // deconstruct graph
10651c0b2f7Stbbdev             input_node_helper<N, JType>::remove_recirc_func_nodes(*my_join, *my_input_join);
10751c0b2f7Stbbdev             tbb::flow::remove_edge(*my_join, recreate_tag);
10851c0b2f7Stbbdev             makeJoin<N, JType, tbb::flow::tag_matching>::destroy(my_join);
10951c0b2f7Stbbdev             tbb::flow::remove_edge(tag_queue, tbb::flow::input_port<0>(*my_input_join));
11051c0b2f7Stbbdev             tbb::flow::remove_edge(snode, tbb::flow::input_port<1>(*my_input_join));
11151c0b2f7Stbbdev             makeJoin<2, input_join_type, tbb::flow::reserving>::destroy(my_input_join);
11251c0b2f7Stbbdev         }
11351c0b2f7Stbbdev     }
11451c0b2f7Stbbdev };
11551c0b2f7Stbbdev 
11651c0b2f7Stbbdev template<typename JType>
11751c0b2f7Stbbdev class generate_recirc_test {
11851c0b2f7Stbbdev public:
11951c0b2f7Stbbdev     typedef tbb::flow::join_node<JType, tbb::flow::tag_matching> join_node_type;
do_test()12051c0b2f7Stbbdev     static void do_test() {
12151c0b2f7Stbbdev         tag_recirculation_test<join_node_type>::test();
12251c0b2f7Stbbdev     }
12351c0b2f7Stbbdev };
12451c0b2f7Stbbdev 
12551c0b2f7Stbbdev //! Test hash buffers behavior
12651c0b2f7Stbbdev //! \brief \ref error_guessing
12751c0b2f7Stbbdev TEST_CASE("Tagged buffers test"){
12851c0b2f7Stbbdev     TestTaggedBuffers();
12951c0b2f7Stbbdev }
13051c0b2f7Stbbdev 
13151c0b2f7Stbbdev //! Test with various policies and tuple sizes
13251c0b2f7Stbbdev //! \brief \ref error_guessing
13351c0b2f7Stbbdev TEST_CASE("Main test"){
13451c0b2f7Stbbdev     test_main<tbb::flow::queueing>();
13551c0b2f7Stbbdev     test_main<tbb::flow::reserving>();
13651c0b2f7Stbbdev     test_main<tbb::flow::tag_matching>();
13751c0b2f7Stbbdev }
13851c0b2f7Stbbdev 
13951c0b2f7Stbbdev //! Test with recirculating tags
14051c0b2f7Stbbdev //! \brief \ref error_guessing
14151c0b2f7Stbbdev TEST_CASE("Recirculation test"){
14251c0b2f7Stbbdev     generate_recirc_test<std::tuple<int,float> >::do_test();
14351c0b2f7Stbbdev }
14451c0b2f7Stbbdev 
// TODO: Look deeper into this test to check whether its name is accurate
// and whether it actually covers a regression. The
// `connect_join_via_follows` and `connect_join_via_precedes` helpers
// may be redundant.
149e0cc5187SIlya Isaev 
15051c0b2f7Stbbdev //! Test maintaining correct count of ports without input
15151c0b2f7Stbbdev //! \brief \ref error_guessing
15251c0b2f7Stbbdev TEST_CASE("Test removal of the predecessor while having none") {
15351c0b2f7Stbbdev     using namespace multiple_predecessors;
15451c0b2f7Stbbdev 
15551c0b2f7Stbbdev     test(connect_join_via_make_edge);
15651c0b2f7Stbbdev }
157