151c0b2f7Stbbdev /*
2*b15aabb3Stbbdev     Copyright (c) 2020-2021 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
17*b15aabb3Stbbdev #if __INTEL_COMPILER && _MSC_VER
18*b15aabb3Stbbdev #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19*b15aabb3Stbbdev #endif
20*b15aabb3Stbbdev 
21*b15aabb3Stbbdev 
2251c0b2f7Stbbdev #include "common/test.h"
2351c0b2f7Stbbdev 
2451c0b2f7Stbbdev #include "common/utils.h"
2551c0b2f7Stbbdev #include "common/graph_utils.h"
2651c0b2f7Stbbdev 
2749e08aacStbbdev #include "oneapi/tbb/flow_graph.h"
2849e08aacStbbdev #include "oneapi/tbb/task_arena.h"
2949e08aacStbbdev #include "oneapi/tbb/global_control.h"
3051c0b2f7Stbbdev 
3151c0b2f7Stbbdev #include "conformance_flowgraph.h"
3251c0b2f7Stbbdev 
3351c0b2f7Stbbdev //! \file conformance_multifunction_node.cpp
//! \brief Test for [flow_graph.multifunction_node] specification
3551c0b2f7Stbbdev 
3651c0b2f7Stbbdev /*
3751c0b2f7Stbbdev TODO: implement missing conformance tests for multifunction_node:
3851c0b2f7Stbbdev   - [ ] Implement test_forwarding that checks messages are broadcast to all the successors connected
3951c0b2f7Stbbdev     to the output port the message is being sent to. And check that the value passed is the
4051c0b2f7Stbbdev     actual one received.
4151c0b2f7Stbbdev   - [ ] Explicit test for copy constructor of the node.
4251c0b2f7Stbbdev   - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body>
4351c0b2f7Stbbdev     multifunction_node( graph &g, size_t concurrency, Body body, Policy(), node_priority_t priority = no_priority )'.
  - [ ] Concurrency testing of the node: loop over the possible concurrency levels. It is
    important to test on at least five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism'
    obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if
    `max_allowed_parallelism' is > 2, a value in the middle of the [1, max_allowed_parallelism]
    interval. Use the `utils::ExactConcurrencyLevel' entity (extending it if necessary).
4951c0b2f7Stbbdev   - [ ] make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads;
5051c0b2f7Stbbdev     add check that `try_put()' returns `false'
5151c0b2f7Stbbdev   - [ ] The `copy_body' function copies altered body (e.g. after successful `try_put()' call).
5251c0b2f7Stbbdev   - [ ] `output_ports_type' is defined and accessible by the user.
5351c0b2f7Stbbdev   - [ ] Explicit test on `mfn::output_ports()' method.
5451c0b2f7Stbbdev   - [ ] The copy constructor and copy assignment are called for the node's input and output types.
5551c0b2f7Stbbdev   - [ ] Add CTAD test.
5651c0b2f7Stbbdev */
5751c0b2f7Stbbdev 
5851c0b2f7Stbbdev template< typename OutputType >
5951c0b2f7Stbbdev struct mf_functor {
6051c0b2f7Stbbdev 
6151c0b2f7Stbbdev     std::atomic<std::size_t>& local_execute_count;
6251c0b2f7Stbbdev 
6351c0b2f7Stbbdev     mf_functor(std::atomic<std::size_t>& execute_count ) :
6451c0b2f7Stbbdev         local_execute_count (execute_count)
6551c0b2f7Stbbdev     {  }
6651c0b2f7Stbbdev 
6751c0b2f7Stbbdev     mf_functor( const mf_functor &f ) : local_execute_count(f.local_execute_count) { }
6851c0b2f7Stbbdev     void operator=(const mf_functor &f) { local_execute_count = std::size_t(f.local_execute_count); }
6951c0b2f7Stbbdev 
7049e08aacStbbdev     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
7151c0b2f7Stbbdev        ++local_execute_count;
7251c0b2f7Stbbdev        std::get<0>(op).try_put(argument);
7351c0b2f7Stbbdev     }
7451c0b2f7Stbbdev 
7551c0b2f7Stbbdev };
7651c0b2f7Stbbdev 
7751c0b2f7Stbbdev template<typename I, typename O>
7851c0b2f7Stbbdev void test_inheritance(){
7949e08aacStbbdev     using namespace oneapi::tbb::flow;
8051c0b2f7Stbbdev 
8151c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, multifunction_node<I, O>>::value), "multifunction_node should be derived from graph_node");
8251c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<receiver<I>, multifunction_node<I, O>>::value), "multifunction_node should be derived from receiver<Input>");
8351c0b2f7Stbbdev }
8451c0b2f7Stbbdev 
8551c0b2f7Stbbdev void test_multifunc_body(){
8649e08aacStbbdev     oneapi::tbb::flow::graph g;
8751c0b2f7Stbbdev     std::atomic<size_t> local_count(0);
8851c0b2f7Stbbdev     mf_functor<std::tuple<int>> fun(local_count);
8951c0b2f7Stbbdev 
9049e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>, oneapi::tbb::flow::rejecting> node1(g, oneapi::tbb::flow::unlimited, fun);
9151c0b2f7Stbbdev 
9251c0b2f7Stbbdev     const size_t n = 10;
9351c0b2f7Stbbdev     for(size_t i = 0; i < n; ++i) {
9451c0b2f7Stbbdev         CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true");
9551c0b2f7Stbbdev     }
9651c0b2f7Stbbdev     g.wait_for_all();
9751c0b2f7Stbbdev 
9851c0b2f7Stbbdev     CHECK_MESSAGE( (local_count == n), "Body of the node needs to be executed N times");
9951c0b2f7Stbbdev }
10051c0b2f7Stbbdev 
10151c0b2f7Stbbdev template<typename I, typename O>
10251c0b2f7Stbbdev struct CopyCounterBody{
10351c0b2f7Stbbdev     size_t copy_count;
10451c0b2f7Stbbdev 
10551c0b2f7Stbbdev     CopyCounterBody():
10651c0b2f7Stbbdev         copy_count(0) {}
10751c0b2f7Stbbdev 
10851c0b2f7Stbbdev     CopyCounterBody(const CopyCounterBody<I, O>& other):
10951c0b2f7Stbbdev         copy_count(other.copy_count + 1) {}
11051c0b2f7Stbbdev 
11151c0b2f7Stbbdev     CopyCounterBody& operator=(const CopyCounterBody<I, O>& other)
11251c0b2f7Stbbdev     { copy_count = other.copy_count + 1; return *this;}
11351c0b2f7Stbbdev 
11449e08aacStbbdev     void operator()( const I& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
11551c0b2f7Stbbdev        std::get<0>(op).try_put(argument);
11651c0b2f7Stbbdev     }
11751c0b2f7Stbbdev };
11851c0b2f7Stbbdev 
11951c0b2f7Stbbdev void test_copies(){
12049e08aacStbbdev      using namespace oneapi::tbb::flow;
12151c0b2f7Stbbdev 
12251c0b2f7Stbbdev      CopyCounterBody<int, std::tuple<int>> b;
12351c0b2f7Stbbdev 
12451c0b2f7Stbbdev      graph g;
12551c0b2f7Stbbdev      multifunction_node<int, std::tuple<int>> fn(g, unlimited, b);
12651c0b2f7Stbbdev 
12751c0b2f7Stbbdev      CopyCounterBody<int, std::tuple<int>> b2 = copy_body<CopyCounterBody<int, std::tuple<int>>,
12851c0b2f7Stbbdev                                                           multifunction_node<int, std::tuple<int>>>(fn);
12951c0b2f7Stbbdev 
13051c0b2f7Stbbdev      CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
13151c0b2f7Stbbdev }
13251c0b2f7Stbbdev 
13351c0b2f7Stbbdev template< typename OutputType >
13451c0b2f7Stbbdev struct id_functor {
13549e08aacStbbdev     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
13651c0b2f7Stbbdev        std::get<0>(op).try_put(argument);
13751c0b2f7Stbbdev     }
13851c0b2f7Stbbdev };
13951c0b2f7Stbbdev 
14051c0b2f7Stbbdev void test_forwarding(){
14149e08aacStbbdev     oneapi::tbb::flow::graph g;
14251c0b2f7Stbbdev     id_functor<int> fun;
14351c0b2f7Stbbdev 
14449e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> node1(g, oneapi::tbb::flow::unlimited, fun);
14551c0b2f7Stbbdev     test_push_receiver<int> node2(g);
14651c0b2f7Stbbdev     test_push_receiver<int> node3(g);
14751c0b2f7Stbbdev 
14849e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
14949e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node3);
15051c0b2f7Stbbdev 
15151c0b2f7Stbbdev     node1.try_put(1);
15251c0b2f7Stbbdev     g.wait_for_all();
15351c0b2f7Stbbdev 
15451c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
15551c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
15651c0b2f7Stbbdev }
15751c0b2f7Stbbdev 
15851c0b2f7Stbbdev void test_rejecting_buffering(){
15949e08aacStbbdev     oneapi::tbb::flow::graph g;
16051c0b2f7Stbbdev     id_functor<int> fun;
16151c0b2f7Stbbdev 
16249e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>, oneapi::tbb::flow::rejecting> node(g, oneapi::tbb::flow::unlimited, fun);
16349e08aacStbbdev     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
16451c0b2f7Stbbdev 
16549e08aacStbbdev     oneapi::tbb::flow::make_edge(node, rejecter);
16651c0b2f7Stbbdev     node.try_put(1);
16751c0b2f7Stbbdev 
16851c0b2f7Stbbdev     int tmp = -1;
16951c0b2f7Stbbdev     CHECK_MESSAGE( (std::get<0>(node.output_ports()).try_get(tmp) == false), "try_get after rejection should not succeed");
17051c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should alter passed value");
17151c0b2f7Stbbdev     g.wait_for_all();
17251c0b2f7Stbbdev }
17351c0b2f7Stbbdev 
17451c0b2f7Stbbdev void test_policy_ctors(){
17549e08aacStbbdev     using namespace oneapi::tbb::flow;
17651c0b2f7Stbbdev     graph g;
17751c0b2f7Stbbdev 
17851c0b2f7Stbbdev     id_functor<int> fun;
17951c0b2f7Stbbdev 
18049e08aacStbbdev     multifunction_node<int, std::tuple<int>, lightweight> lw_node(g, oneapi::tbb::flow::serial, fun);
18149e08aacStbbdev     multifunction_node<int, std::tuple<int>, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial, fun);
18249e08aacStbbdev     multifunction_node<int, std::tuple<int>, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial, fun);
18351c0b2f7Stbbdev 
18451c0b2f7Stbbdev }
18551c0b2f7Stbbdev 
// Shared by concurrency_functor and test_node_concurrency: the number of bodies
// currently executing, and the highest value that number has reached.
std::atomic<std::size_t> my_concurrency{0};
std::atomic<std::size_t> my_max_concurrency{0};
18851c0b2f7Stbbdev 
18951c0b2f7Stbbdev struct concurrency_functor {
19049e08aacStbbdev     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
19151c0b2f7Stbbdev         ++my_concurrency;
19251c0b2f7Stbbdev 
19351c0b2f7Stbbdev         size_t old_value = my_max_concurrency;
19451c0b2f7Stbbdev         while(my_max_concurrency < my_concurrency &&
19551c0b2f7Stbbdev               !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency))
19651c0b2f7Stbbdev             ;
19751c0b2f7Stbbdev 
19851c0b2f7Stbbdev         size_t ms = 1000;
19951c0b2f7Stbbdev         std::chrono::milliseconds sleep_time( ms );
20051c0b2f7Stbbdev         std::this_thread::sleep_for( sleep_time );
20151c0b2f7Stbbdev 
20251c0b2f7Stbbdev         --my_concurrency;
20351c0b2f7Stbbdev         std::get<0>(op).try_put(argument);
20451c0b2f7Stbbdev     }
20551c0b2f7Stbbdev 
20651c0b2f7Stbbdev };
20751c0b2f7Stbbdev 
20851c0b2f7Stbbdev void test_node_concurrency(){
20951c0b2f7Stbbdev     my_concurrency = 0;
21051c0b2f7Stbbdev     my_max_concurrency = 0;
21151c0b2f7Stbbdev 
21249e08aacStbbdev     oneapi::tbb::flow::graph g;
21351c0b2f7Stbbdev 
21451c0b2f7Stbbdev     concurrency_functor counter;
21549e08aacStbbdev     oneapi::tbb::flow::multifunction_node <int, std::tuple<int>> fnode(g, oneapi::tbb::flow::serial, counter);
21651c0b2f7Stbbdev 
21751c0b2f7Stbbdev     test_push_receiver<int> sink(g);
21851c0b2f7Stbbdev 
21951c0b2f7Stbbdev     make_edge(std::get<0>(fnode.output_ports()), sink);
22051c0b2f7Stbbdev 
22151c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
22251c0b2f7Stbbdev         fnode.try_put(i);
22351c0b2f7Stbbdev     }
22451c0b2f7Stbbdev 
22551c0b2f7Stbbdev     g.wait_for_all();
22651c0b2f7Stbbdev     CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism over limit");
22751c0b2f7Stbbdev }
22851c0b2f7Stbbdev 
22951c0b2f7Stbbdev 
23051c0b2f7Stbbdev void test_priority(){
23151c0b2f7Stbbdev     size_t concurrency_limit = 1;
23249e08aacStbbdev     oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit);
23351c0b2f7Stbbdev 
23449e08aacStbbdev     oneapi::tbb::flow::graph g;
23551c0b2f7Stbbdev 
23649e08aacStbbdev     oneapi::tbb::flow::continue_node<int> source(g,
23749e08aacStbbdev                                          [](oneapi::tbb::flow::continue_msg){ return 1;});
23849e08aacStbbdev     source.try_put(oneapi::tbb::flow::continue_msg());
23951c0b2f7Stbbdev 
24051c0b2f7Stbbdev     first_functor<int>::first_id = -1;
24151c0b2f7Stbbdev     first_functor<int> low_functor(1);
24251c0b2f7Stbbdev     first_functor<int> high_functor(2);
24351c0b2f7Stbbdev 
24449e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1));
24549e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> low(g, oneapi::tbb::flow::unlimited, low_functor);
24651c0b2f7Stbbdev 
24751c0b2f7Stbbdev     make_edge(source, low);
24851c0b2f7Stbbdev     make_edge(source, high);
24951c0b2f7Stbbdev 
25051c0b2f7Stbbdev     g.wait_for_all();
25151c0b2f7Stbbdev 
25251c0b2f7Stbbdev     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
25351c0b2f7Stbbdev }
25451c0b2f7Stbbdev 
25551c0b2f7Stbbdev void test_rejecting(){
25649e08aacStbbdev     oneapi::tbb::flow::graph g;
25749e08aacStbbdev     oneapi::tbb::flow::multifunction_node <int, std::tuple<int>, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial,
25849e08aacStbbdev                                                                     [&](const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ){
25951c0b2f7Stbbdev                                                                         size_t ms = 50;
26051c0b2f7Stbbdev                                                                         std::chrono::milliseconds sleep_time( ms );
26151c0b2f7Stbbdev                                                                         std::this_thread::sleep_for( sleep_time );
26251c0b2f7Stbbdev                                                                         std::get<0>(op).try_put(argument);
26351c0b2f7Stbbdev                                                                     });
26451c0b2f7Stbbdev 
26551c0b2f7Stbbdev     test_push_receiver<int> sink(g);
26651c0b2f7Stbbdev 
26751c0b2f7Stbbdev     make_edge(std::get<0>(fnode.output_ports()), sink);
26851c0b2f7Stbbdev 
26951c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
27051c0b2f7Stbbdev         fnode.try_put(i);
27151c0b2f7Stbbdev     }
27251c0b2f7Stbbdev 
27351c0b2f7Stbbdev     g.wait_for_all();
27451c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed");
27551c0b2f7Stbbdev }
27651c0b2f7Stbbdev 
27751c0b2f7Stbbdev //! Test multifunction_node with rejecting policy
27851c0b2f7Stbbdev //! \brief \ref interface
27951c0b2f7Stbbdev TEST_CASE("multifunction_node with rejecting policy"){
28051c0b2f7Stbbdev     test_rejecting();
28151c0b2f7Stbbdev }
28251c0b2f7Stbbdev 
28351c0b2f7Stbbdev //! Test priorities
28451c0b2f7Stbbdev //! \brief \ref interface
28551c0b2f7Stbbdev TEST_CASE("multifunction_node priority"){
28651c0b2f7Stbbdev     test_priority();
28751c0b2f7Stbbdev }
28851c0b2f7Stbbdev 
28951c0b2f7Stbbdev //! Test concurrency
29051c0b2f7Stbbdev //! \brief \ref interface
29151c0b2f7Stbbdev TEST_CASE("multifunction_node concurrency"){
29251c0b2f7Stbbdev     test_node_concurrency();
29351c0b2f7Stbbdev }
29451c0b2f7Stbbdev 
29551c0b2f7Stbbdev //! Test constructors
29651c0b2f7Stbbdev //! \brief \ref interface
29751c0b2f7Stbbdev TEST_CASE("multifunction_node constructors"){
29851c0b2f7Stbbdev     test_policy_ctors();
29951c0b2f7Stbbdev }
30051c0b2f7Stbbdev 
30151c0b2f7Stbbdev //! Test function_node buffering
30251c0b2f7Stbbdev //! \brief \ref requirement
30351c0b2f7Stbbdev TEST_CASE("multifunction_node buffering"){
30451c0b2f7Stbbdev     test_rejecting_buffering();
30551c0b2f7Stbbdev }
30651c0b2f7Stbbdev 
30751c0b2f7Stbbdev //! Test function_node broadcasting
30851c0b2f7Stbbdev //! \brief \ref requirement
30951c0b2f7Stbbdev TEST_CASE("multifunction_node broadcast"){
31051c0b2f7Stbbdev     test_forwarding();
31151c0b2f7Stbbdev }
31251c0b2f7Stbbdev 
31351c0b2f7Stbbdev //! Test body copying and copy_body logic
31451c0b2f7Stbbdev //! \brief \ref interface
31551c0b2f7Stbbdev TEST_CASE("multifunction_node constructors"){
31651c0b2f7Stbbdev     test_copies();
31751c0b2f7Stbbdev }
31851c0b2f7Stbbdev 
31951c0b2f7Stbbdev //! Test calling function body
32051c0b2f7Stbbdev //! \brief \ref interface \ref requirement
32151c0b2f7Stbbdev TEST_CASE("multifunction_node body") {
32251c0b2f7Stbbdev     test_multifunc_body();
32351c0b2f7Stbbdev }
32451c0b2f7Stbbdev 
32551c0b2f7Stbbdev //! Test inheritance relations
32651c0b2f7Stbbdev //! \brief \ref interface
32751c0b2f7Stbbdev TEST_CASE("multifunction_node superclasses"){
32851c0b2f7Stbbdev     test_inheritance<int, std::tuple<int>>();
32951c0b2f7Stbbdev     test_inheritance<void*, std::tuple<float>>();
33051c0b2f7Stbbdev }
331