151c0b2f7Stbbdev /*
2*b15aabb3Stbbdev     Copyright (c) 2020-2021 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
17*b15aabb3Stbbdev #if __INTEL_COMPILER && _MSC_VER
18*b15aabb3Stbbdev #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19*b15aabb3Stbbdev #endif
2051c0b2f7Stbbdev 
#include "common/test.h"

#include "common/utils.h"
#include "common/graph_utils.h"

#include "oneapi/tbb/flow_graph.h"
#include "oneapi/tbb/task_arena.h"
#include "oneapi/tbb/global_control.h"

#include "conformance_flowgraph.h"

#include <atomic>
3151c0b2f7Stbbdev 
3251c0b2f7Stbbdev //! \file conformance_join_node.cpp
3351c0b2f7Stbbdev //! \brief Test for [flow_graph.join_node] specification
3451c0b2f7Stbbdev 
3551c0b2f7Stbbdev /*
3651c0b2f7Stbbdev TODO: implement missing conformance tests for join_node:
3751c0b2f7Stbbdev   - [ ] Check that `OutputTuple' is an instantiation of a tuple.
3851c0b2f7Stbbdev   - [ ] The copy constructor and copy assignment are called for each type within the `OutputTuple'.
3951c0b2f7Stbbdev   - [ ] Check all possible policies of the node: `reserving', `key_matching', `queueing',
4051c0b2f7Stbbdev     `tag_matching'. Check the semantics the node has with each policy separately.
4151c0b2f7Stbbdev   - [ ] Check that corresponding methods are invoked in specified `KHash' type.
4251c0b2f7Stbbdev   - [ ] Improve test for constructors, including their availability based on used Policy for the
4351c0b2f7Stbbdev     node.
4451c0b2f7Stbbdev   - [ ] Unify code style in the test by extracting the implementation from the `TEST_CASE' scope
4551c0b2f7Stbbdev     into separate functions.
4651c0b2f7Stbbdev   - [ ] Check that corresponding methods mentioned in the requirements are called for `Bi' types.
  - [ ] Explicitly check that `input_ports_type' is defined, accessible and is a tuple of
    receivers corresponding to the `OutputTuple' element types.
4951c0b2f7Stbbdev   - [ ] Explicitly check the method `join_node::input_ports()' exists, is accessible and it returns
5051c0b2f7Stbbdev     a reference to the `input_ports_type' type.
5151c0b2f7Stbbdev   - [ ] Implement `test_buffering' (for node policy).
  - [ ] Check that `try_get()' copies the generated tuple into the passed argument and returns
    `true'. If the node is empty, it returns `false'.
5451c0b2f7Stbbdev   - [ ] Check `tag_value' is defined and has properties specified.
5551c0b2f7Stbbdev   - [ ] Add test for CTAD.
5651c0b2f7Stbbdev */
5751c0b2f7Stbbdev 
5849e08aacStbbdev using namespace oneapi::tbb::flow;
5951c0b2f7Stbbdev using namespace std;
6051c0b2f7Stbbdev 
6151c0b2f7Stbbdev template<typename T>
6251c0b2f7Stbbdev void test_inheritance(){
6351c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, join_node<std::tuple<T, T>>>::value), "join_node should be derived from graph_node");
6451c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<sender<std::tuple<T, T>>, join_node<std::tuple<T, T>>>::value), "join_node should be derived from graph_node");
6551c0b2f7Stbbdev }
6651c0b2f7Stbbdev 
6751c0b2f7Stbbdev void test_copies(){
6849e08aacStbbdev     using namespace oneapi::tbb::flow;
6951c0b2f7Stbbdev 
7051c0b2f7Stbbdev     graph g;
7151c0b2f7Stbbdev     join_node<std::tuple<int, int>> n(g);
7251c0b2f7Stbbdev     join_node<std::tuple<int, int>> n2(n);
7351c0b2f7Stbbdev 
7449e08aacStbbdev     join_node <std::tuple<int, int, oneapi::tbb::flow::reserving>> nr(g);
7549e08aacStbbdev     join_node <std::tuple<int, int, oneapi::tbb::flow::reserving>> nr2(nr);
7651c0b2f7Stbbdev }
7751c0b2f7Stbbdev 
7851c0b2f7Stbbdev void test_forwarding(){
7949e08aacStbbdev     oneapi::tbb::flow::graph g;
8051c0b2f7Stbbdev 
8151c0b2f7Stbbdev     join_node<std::tuple<int, int>> node1(g);
8251c0b2f7Stbbdev 
8351c0b2f7Stbbdev     using output_t = join_node<std::tuple<int, int>>::output_type;
8451c0b2f7Stbbdev 
8551c0b2f7Stbbdev     test_push_receiver<output_t> node2(g);
8651c0b2f7Stbbdev     test_push_receiver<output_t> node3(g);
8751c0b2f7Stbbdev 
8849e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
8949e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node3);
9051c0b2f7Stbbdev 
9151c0b2f7Stbbdev     input_port<0>(node1).try_put(1);
9251c0b2f7Stbbdev     input_port<1>(node1).try_put(1);
9351c0b2f7Stbbdev 
9451c0b2f7Stbbdev     g.wait_for_all();
9551c0b2f7Stbbdev 
9651c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node needs to be receive N messages");
9751c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
9851c0b2f7Stbbdev }
9951c0b2f7Stbbdev 
10051c0b2f7Stbbdev //! Test broadcast
10151c0b2f7Stbbdev //! \brief \ref interface
10251c0b2f7Stbbdev TEST_CASE("join_node broadcast") {
10351c0b2f7Stbbdev     test_forwarding();
10451c0b2f7Stbbdev }
10551c0b2f7Stbbdev 
10651c0b2f7Stbbdev 
10751c0b2f7Stbbdev //! Test copy constructor
10851c0b2f7Stbbdev //! \brief \ref interface
10951c0b2f7Stbbdev TEST_CASE("join_node copy constructor") {
11051c0b2f7Stbbdev     test_copies();
11151c0b2f7Stbbdev }
11251c0b2f7Stbbdev 
11351c0b2f7Stbbdev //! Test inheritance relations
11451c0b2f7Stbbdev //! \brief \ref interface
11551c0b2f7Stbbdev TEST_CASE("join_node inheritance"){
11651c0b2f7Stbbdev     test_inheritance<int>();
11751c0b2f7Stbbdev }
11851c0b2f7Stbbdev 
11951c0b2f7Stbbdev //! Test join_node behavior
12051c0b2f7Stbbdev //! \brief \ref requirement
12151c0b2f7Stbbdev TEST_CASE("join_node") {
12251c0b2f7Stbbdev     graph g;
12351c0b2f7Stbbdev     function_node<int,int>
12451c0b2f7Stbbdev         f1( g, unlimited, [](const int &i) { return 2*i; } );
12551c0b2f7Stbbdev     function_node<float,float>
12651c0b2f7Stbbdev         f2( g, unlimited, [](const float &f) { return f/2; } );
12751c0b2f7Stbbdev 
12851c0b2f7Stbbdev     join_node< std::tuple<int,float> > j(g);
12951c0b2f7Stbbdev 
13051c0b2f7Stbbdev     function_node< std::tuple<int,float> >
13151c0b2f7Stbbdev         f3( g, unlimited,
13251c0b2f7Stbbdev             []( const std::tuple<int,float> &t ) {
13351c0b2f7Stbbdev                 CHECK_MESSAGE( (std::get<0>(t) == 6), "Expected to receive 6" );
13451c0b2f7Stbbdev                 CHECK_MESSAGE( (std::get<1>(t) == 1.5), "Expected to receive 1.5" );
13551c0b2f7Stbbdev             } );
13651c0b2f7Stbbdev 
13751c0b2f7Stbbdev     make_edge( f1, input_port<0>( j ) );
13851c0b2f7Stbbdev     make_edge( f2, input_port<1>( j ) );
13951c0b2f7Stbbdev     make_edge( j, f3 );
14051c0b2f7Stbbdev 
14151c0b2f7Stbbdev     f1.try_put( 3 );
14251c0b2f7Stbbdev     f2.try_put( 3 );
14351c0b2f7Stbbdev     g.wait_for_all( );
14451c0b2f7Stbbdev }
14551c0b2f7Stbbdev 
14651c0b2f7Stbbdev //! Test join_node key matching behavior
14751c0b2f7Stbbdev //! \brief \ref requirement
14851c0b2f7Stbbdev TEST_CASE("remove edge to join_node"){
14951c0b2f7Stbbdev     graph g;
15051c0b2f7Stbbdev     continue_node<int> c(g, [](const continue_msg&){ return 1; });
15151c0b2f7Stbbdev     join_node<tuple<int> > jn(g);
15251c0b2f7Stbbdev     queue_node<tuple<int> > q(g);
15351c0b2f7Stbbdev 
15451c0b2f7Stbbdev     make_edge(jn, q);
15551c0b2f7Stbbdev 
15651c0b2f7Stbbdev     make_edge(c, jn);
15751c0b2f7Stbbdev 
15851c0b2f7Stbbdev     c.try_put(continue_msg());
15951c0b2f7Stbbdev     g.wait_for_all();
16051c0b2f7Stbbdev 
16151c0b2f7Stbbdev     tuple<int> tmp = tuple<int>(0);
16251c0b2f7Stbbdev     CHECK_MESSAGE( (q.try_get(tmp)== true), "Message should pass when edge exists");
16351c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == tuple<int>(1) ), "Message should pass when edge exists");
16451c0b2f7Stbbdev     CHECK_MESSAGE( (q.try_get(tmp)== false), "Message should not pass after item is consumed");
16551c0b2f7Stbbdev 
16651c0b2f7Stbbdev     remove_edge(c, jn);
16751c0b2f7Stbbdev 
16851c0b2f7Stbbdev     c.try_put(continue_msg());
16951c0b2f7Stbbdev     g.wait_for_all();
17051c0b2f7Stbbdev 
17151c0b2f7Stbbdev     tmp = tuple<int>(0);
17251c0b2f7Stbbdev     CHECK_MESSAGE( (q.try_get(tmp)== false), "Message should not pass when edge doesn't exist");
17351c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == tuple<int>(0)), "Value should not be altered");
17451c0b2f7Stbbdev }
17551c0b2f7Stbbdev 
17651c0b2f7Stbbdev //! Test join_node key matching behavior
17751c0b2f7Stbbdev //! \brief \ref requirement
17851c0b2f7Stbbdev TEST_CASE("join_node key_matching"){
17951c0b2f7Stbbdev     graph g;
18051c0b2f7Stbbdev     auto body1 = [](const continue_msg &) -> int { return 1; };
18151c0b2f7Stbbdev     auto body2 = [](const double &val) -> int { return int(val); };
18251c0b2f7Stbbdev 
18351c0b2f7Stbbdev     join_node<std::tuple<continue_msg, double>, key_matching<int>> jn(g, body1, body2);
18451c0b2f7Stbbdev 
18551c0b2f7Stbbdev     input_port<0>(jn).try_put(continue_msg());
18651c0b2f7Stbbdev     input_port<1>(jn).try_put(1.3);
18751c0b2f7Stbbdev 
18851c0b2f7Stbbdev     g.wait_for_all( );
18951c0b2f7Stbbdev 
19051c0b2f7Stbbdev     tuple<continue_msg, double> tmp;
19151c0b2f7Stbbdev     CHECK_MESSAGE( (jn.try_get(tmp) == true), "Mapped keys should match");
19251c0b2f7Stbbdev }
193