151c0b2f7Stbbdev /*
251c0b2f7Stbbdev     Copyright (c) 2020 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
1751c0b2f7Stbbdev #include "common/test.h"
1851c0b2f7Stbbdev 
1951c0b2f7Stbbdev #include "common/utils.h"
2051c0b2f7Stbbdev #include "common/graph_utils.h"
2151c0b2f7Stbbdev 
22*49e08aacStbbdev #include "oneapi/tbb/flow_graph.h"
23*49e08aacStbbdev #include "oneapi/tbb/task_arena.h"
24*49e08aacStbbdev #include "oneapi/tbb/global_control.h"
2551c0b2f7Stbbdev 
2651c0b2f7Stbbdev #include "conformance_flowgraph.h"
2751c0b2f7Stbbdev 
2851c0b2f7Stbbdev //! \file conformance_multifunction_node.cpp
//! \brief Test for [flow_graph.multifunction_node] specification
3051c0b2f7Stbbdev 
3151c0b2f7Stbbdev /*
3251c0b2f7Stbbdev TODO: implement missing conformance tests for multifunction_node:
3351c0b2f7Stbbdev   - [ ] Implement test_forwarding that checks messages are broadcast to all the successors connected
3451c0b2f7Stbbdev     to the output port the message is being sent to. And check that the value passed is the
3551c0b2f7Stbbdev     actual one received.
3651c0b2f7Stbbdev   - [ ] Explicit test for copy constructor of the node.
3751c0b2f7Stbbdev   - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body>
3851c0b2f7Stbbdev     multifunction_node( graph &g, size_t concurrency, Body body, Policy(), node_priority_t priority = no_priority )'.
  - [ ] Concurrency testing of the node: make a loop over possible concurrency levels. It is
    important to test at least these five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism'
    obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if
    `max_allowed_parallelism' is > 2, something in the middle of the [1, max_allowed_parallelism]
    interval. Use the `utils::ExactConcurrencyLevel' entity (extending it if necessary).
4451c0b2f7Stbbdev   - [ ] make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads;
4551c0b2f7Stbbdev     add check that `try_put()' returns `false'
  - [ ] The `copy_body' function copies the altered body (e.g. after a successful `try_put()' call).
4751c0b2f7Stbbdev   - [ ] `output_ports_type' is defined and accessible by the user.
4851c0b2f7Stbbdev   - [ ] Explicit test on `mfn::output_ports()' method.
4951c0b2f7Stbbdev   - [ ] The copy constructor and copy assignment are called for the node's input and output types.
5051c0b2f7Stbbdev   - [ ] Add CTAD test.
5151c0b2f7Stbbdev */
5251c0b2f7Stbbdev 
5351c0b2f7Stbbdev template< typename OutputType >
5451c0b2f7Stbbdev struct mf_functor {
5551c0b2f7Stbbdev 
5651c0b2f7Stbbdev     std::atomic<std::size_t>& local_execute_count;
5751c0b2f7Stbbdev 
5851c0b2f7Stbbdev     mf_functor(std::atomic<std::size_t>& execute_count ) :
5951c0b2f7Stbbdev         local_execute_count (execute_count)
6051c0b2f7Stbbdev     {  }
6151c0b2f7Stbbdev 
6251c0b2f7Stbbdev     mf_functor( const mf_functor &f ) : local_execute_count(f.local_execute_count) { }
6351c0b2f7Stbbdev     void operator=(const mf_functor &f) { local_execute_count = std::size_t(f.local_execute_count); }
6451c0b2f7Stbbdev 
65*49e08aacStbbdev     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
6651c0b2f7Stbbdev        ++local_execute_count;
6751c0b2f7Stbbdev        std::get<0>(op).try_put(argument);
6851c0b2f7Stbbdev     }
6951c0b2f7Stbbdev 
7051c0b2f7Stbbdev };
7151c0b2f7Stbbdev 
7251c0b2f7Stbbdev template<typename I, typename O>
7351c0b2f7Stbbdev void test_inheritance(){
74*49e08aacStbbdev     using namespace oneapi::tbb::flow;
7551c0b2f7Stbbdev 
7651c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, multifunction_node<I, O>>::value), "multifunction_node should be derived from graph_node");
7751c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<receiver<I>, multifunction_node<I, O>>::value), "multifunction_node should be derived from receiver<Input>");
7851c0b2f7Stbbdev }
7951c0b2f7Stbbdev 
8051c0b2f7Stbbdev void test_multifunc_body(){
81*49e08aacStbbdev     oneapi::tbb::flow::graph g;
8251c0b2f7Stbbdev     std::atomic<size_t> local_count(0);
8351c0b2f7Stbbdev     mf_functor<std::tuple<int>> fun(local_count);
8451c0b2f7Stbbdev 
85*49e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>, oneapi::tbb::flow::rejecting> node1(g, oneapi::tbb::flow::unlimited, fun);
8651c0b2f7Stbbdev 
8751c0b2f7Stbbdev     const size_t n = 10;
8851c0b2f7Stbbdev     for(size_t i = 0; i < n; ++i) {
8951c0b2f7Stbbdev         CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true");
9051c0b2f7Stbbdev     }
9151c0b2f7Stbbdev     g.wait_for_all();
9251c0b2f7Stbbdev 
9351c0b2f7Stbbdev     CHECK_MESSAGE( (local_count == n), "Body of the node needs to be executed N times");
9451c0b2f7Stbbdev }
9551c0b2f7Stbbdev 
9651c0b2f7Stbbdev template<typename I, typename O>
9751c0b2f7Stbbdev struct CopyCounterBody{
9851c0b2f7Stbbdev     size_t copy_count;
9951c0b2f7Stbbdev 
10051c0b2f7Stbbdev     CopyCounterBody():
10151c0b2f7Stbbdev         copy_count(0) {}
10251c0b2f7Stbbdev 
10351c0b2f7Stbbdev     CopyCounterBody(const CopyCounterBody<I, O>& other):
10451c0b2f7Stbbdev         copy_count(other.copy_count + 1) {}
10551c0b2f7Stbbdev 
10651c0b2f7Stbbdev     CopyCounterBody& operator=(const CopyCounterBody<I, O>& other)
10751c0b2f7Stbbdev     { copy_count = other.copy_count + 1; return *this;}
10851c0b2f7Stbbdev 
109*49e08aacStbbdev     void operator()( const I& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
11051c0b2f7Stbbdev        std::get<0>(op).try_put(argument);
11151c0b2f7Stbbdev     }
11251c0b2f7Stbbdev };
11351c0b2f7Stbbdev 
11451c0b2f7Stbbdev void test_copies(){
115*49e08aacStbbdev      using namespace oneapi::tbb::flow;
11651c0b2f7Stbbdev 
11751c0b2f7Stbbdev      CopyCounterBody<int, std::tuple<int>> b;
11851c0b2f7Stbbdev 
11951c0b2f7Stbbdev      graph g;
12051c0b2f7Stbbdev      multifunction_node<int, std::tuple<int>> fn(g, unlimited, b);
12151c0b2f7Stbbdev 
12251c0b2f7Stbbdev      CopyCounterBody<int, std::tuple<int>> b2 = copy_body<CopyCounterBody<int, std::tuple<int>>,
12351c0b2f7Stbbdev                                                           multifunction_node<int, std::tuple<int>>>(fn);
12451c0b2f7Stbbdev 
12551c0b2f7Stbbdev      CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
12651c0b2f7Stbbdev }
12751c0b2f7Stbbdev 
12851c0b2f7Stbbdev template< typename OutputType >
12951c0b2f7Stbbdev struct id_functor {
130*49e08aacStbbdev     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
13151c0b2f7Stbbdev        std::get<0>(op).try_put(argument);
13251c0b2f7Stbbdev     }
13351c0b2f7Stbbdev };
13451c0b2f7Stbbdev 
13551c0b2f7Stbbdev void test_forwarding(){
136*49e08aacStbbdev     oneapi::tbb::flow::graph g;
13751c0b2f7Stbbdev     id_functor<int> fun;
13851c0b2f7Stbbdev 
139*49e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> node1(g, oneapi::tbb::flow::unlimited, fun);
14051c0b2f7Stbbdev     test_push_receiver<int> node2(g);
14151c0b2f7Stbbdev     test_push_receiver<int> node3(g);
14251c0b2f7Stbbdev 
143*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
144*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node3);
14551c0b2f7Stbbdev 
14651c0b2f7Stbbdev     node1.try_put(1);
14751c0b2f7Stbbdev     g.wait_for_all();
14851c0b2f7Stbbdev 
14951c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
15051c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
15151c0b2f7Stbbdev }
15251c0b2f7Stbbdev 
15351c0b2f7Stbbdev void test_rejecting_buffering(){
154*49e08aacStbbdev     oneapi::tbb::flow::graph g;
15551c0b2f7Stbbdev     id_functor<int> fun;
15651c0b2f7Stbbdev 
157*49e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>, oneapi::tbb::flow::rejecting> node(g, oneapi::tbb::flow::unlimited, fun);
158*49e08aacStbbdev     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
15951c0b2f7Stbbdev 
160*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node, rejecter);
16151c0b2f7Stbbdev     node.try_put(1);
16251c0b2f7Stbbdev 
16351c0b2f7Stbbdev     int tmp = -1;
16451c0b2f7Stbbdev     CHECK_MESSAGE( (std::get<0>(node.output_ports()).try_get(tmp) == false), "try_get after rejection should not succeed");
16551c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should alter passed value");
16651c0b2f7Stbbdev     g.wait_for_all();
16751c0b2f7Stbbdev }
16851c0b2f7Stbbdev 
16951c0b2f7Stbbdev void test_policy_ctors(){
170*49e08aacStbbdev     using namespace oneapi::tbb::flow;
17151c0b2f7Stbbdev     graph g;
17251c0b2f7Stbbdev 
17351c0b2f7Stbbdev     id_functor<int> fun;
17451c0b2f7Stbbdev 
175*49e08aacStbbdev     multifunction_node<int, std::tuple<int>, lightweight> lw_node(g, oneapi::tbb::flow::serial, fun);
176*49e08aacStbbdev     multifunction_node<int, std::tuple<int>, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial, fun);
177*49e08aacStbbdev     multifunction_node<int, std::tuple<int>, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial, fun);
17851c0b2f7Stbbdev 
17951c0b2f7Stbbdev }
18051c0b2f7Stbbdev 
18151c0b2f7Stbbdev std::atomic<size_t> my_concurrency;
18251c0b2f7Stbbdev std::atomic<size_t> my_max_concurrency;
18351c0b2f7Stbbdev 
18451c0b2f7Stbbdev struct concurrency_functor {
185*49e08aacStbbdev     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
18651c0b2f7Stbbdev         ++my_concurrency;
18751c0b2f7Stbbdev 
18851c0b2f7Stbbdev         size_t old_value = my_max_concurrency;
18951c0b2f7Stbbdev         while(my_max_concurrency < my_concurrency &&
19051c0b2f7Stbbdev               !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency))
19151c0b2f7Stbbdev             ;
19251c0b2f7Stbbdev 
19351c0b2f7Stbbdev         size_t ms = 1000;
19451c0b2f7Stbbdev         std::chrono::milliseconds sleep_time( ms );
19551c0b2f7Stbbdev         std::this_thread::sleep_for( sleep_time );
19651c0b2f7Stbbdev 
19751c0b2f7Stbbdev         --my_concurrency;
19851c0b2f7Stbbdev         std::get<0>(op).try_put(argument);
19951c0b2f7Stbbdev     }
20051c0b2f7Stbbdev 
20151c0b2f7Stbbdev };
20251c0b2f7Stbbdev 
20351c0b2f7Stbbdev void test_node_concurrency(){
20451c0b2f7Stbbdev     my_concurrency = 0;
20551c0b2f7Stbbdev     my_max_concurrency = 0;
20651c0b2f7Stbbdev 
207*49e08aacStbbdev     oneapi::tbb::flow::graph g;
20851c0b2f7Stbbdev 
20951c0b2f7Stbbdev     concurrency_functor counter;
210*49e08aacStbbdev     oneapi::tbb::flow::multifunction_node <int, std::tuple<int>> fnode(g, oneapi::tbb::flow::serial, counter);
21151c0b2f7Stbbdev 
21251c0b2f7Stbbdev     test_push_receiver<int> sink(g);
21351c0b2f7Stbbdev 
21451c0b2f7Stbbdev     make_edge(std::get<0>(fnode.output_ports()), sink);
21551c0b2f7Stbbdev 
21651c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
21751c0b2f7Stbbdev         fnode.try_put(i);
21851c0b2f7Stbbdev     }
21951c0b2f7Stbbdev 
22051c0b2f7Stbbdev     g.wait_for_all();
22151c0b2f7Stbbdev     CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism over limit");
22251c0b2f7Stbbdev }
22351c0b2f7Stbbdev 
22451c0b2f7Stbbdev 
22551c0b2f7Stbbdev void test_priority(){
22651c0b2f7Stbbdev     size_t concurrency_limit = 1;
227*49e08aacStbbdev     oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit);
22851c0b2f7Stbbdev 
229*49e08aacStbbdev     oneapi::tbb::flow::graph g;
23051c0b2f7Stbbdev 
231*49e08aacStbbdev     oneapi::tbb::flow::continue_node<int> source(g,
232*49e08aacStbbdev                                          [](oneapi::tbb::flow::continue_msg){ return 1;});
233*49e08aacStbbdev     source.try_put(oneapi::tbb::flow::continue_msg());
23451c0b2f7Stbbdev 
23551c0b2f7Stbbdev     first_functor<int>::first_id = -1;
23651c0b2f7Stbbdev     first_functor<int> low_functor(1);
23751c0b2f7Stbbdev     first_functor<int> high_functor(2);
23851c0b2f7Stbbdev 
239*49e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1));
240*49e08aacStbbdev     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> low(g, oneapi::tbb::flow::unlimited, low_functor);
24151c0b2f7Stbbdev 
24251c0b2f7Stbbdev     make_edge(source, low);
24351c0b2f7Stbbdev     make_edge(source, high);
24451c0b2f7Stbbdev 
24551c0b2f7Stbbdev     g.wait_for_all();
24651c0b2f7Stbbdev 
24751c0b2f7Stbbdev     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
24851c0b2f7Stbbdev }
24951c0b2f7Stbbdev 
25051c0b2f7Stbbdev void test_rejecting(){
251*49e08aacStbbdev     oneapi::tbb::flow::graph g;
252*49e08aacStbbdev     oneapi::tbb::flow::multifunction_node <int, std::tuple<int>, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial,
253*49e08aacStbbdev                                                                     [&](const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ){
25451c0b2f7Stbbdev                                                                         size_t ms = 50;
25551c0b2f7Stbbdev                                                                         std::chrono::milliseconds sleep_time( ms );
25651c0b2f7Stbbdev                                                                         std::this_thread::sleep_for( sleep_time );
25751c0b2f7Stbbdev                                                                         std::get<0>(op).try_put(argument);
25851c0b2f7Stbbdev                                                                     });
25951c0b2f7Stbbdev 
26051c0b2f7Stbbdev     test_push_receiver<int> sink(g);
26151c0b2f7Stbbdev 
26251c0b2f7Stbbdev     make_edge(std::get<0>(fnode.output_ports()), sink);
26351c0b2f7Stbbdev 
26451c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
26551c0b2f7Stbbdev         fnode.try_put(i);
26651c0b2f7Stbbdev     }
26751c0b2f7Stbbdev 
26851c0b2f7Stbbdev     g.wait_for_all();
26951c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed");
27051c0b2f7Stbbdev }
27151c0b2f7Stbbdev 
27251c0b2f7Stbbdev //! Test multifunction_node with rejecting policy
27351c0b2f7Stbbdev //! \brief \ref interface
27451c0b2f7Stbbdev TEST_CASE("multifunction_node with rejecting policy"){
27551c0b2f7Stbbdev     test_rejecting();
27651c0b2f7Stbbdev }
27751c0b2f7Stbbdev 
27851c0b2f7Stbbdev //! Test priorities
27951c0b2f7Stbbdev //! \brief \ref interface
28051c0b2f7Stbbdev TEST_CASE("multifunction_node priority"){
28151c0b2f7Stbbdev     test_priority();
28251c0b2f7Stbbdev }
28351c0b2f7Stbbdev 
28451c0b2f7Stbbdev //! Test concurrency
28551c0b2f7Stbbdev //! \brief \ref interface
28651c0b2f7Stbbdev TEST_CASE("multifunction_node concurrency"){
28751c0b2f7Stbbdev     test_node_concurrency();
28851c0b2f7Stbbdev }
28951c0b2f7Stbbdev 
29051c0b2f7Stbbdev //! Test constructors
29151c0b2f7Stbbdev //! \brief \ref interface
29251c0b2f7Stbbdev TEST_CASE("multifunction_node constructors"){
29351c0b2f7Stbbdev     test_policy_ctors();
29451c0b2f7Stbbdev }
29551c0b2f7Stbbdev 
29651c0b2f7Stbbdev //! Test function_node buffering
29751c0b2f7Stbbdev //! \brief \ref requirement
29851c0b2f7Stbbdev TEST_CASE("multifunction_node buffering"){
29951c0b2f7Stbbdev     test_rejecting_buffering();
30051c0b2f7Stbbdev }
30151c0b2f7Stbbdev 
30251c0b2f7Stbbdev //! Test function_node broadcasting
30351c0b2f7Stbbdev //! \brief \ref requirement
30451c0b2f7Stbbdev TEST_CASE("multifunction_node broadcast"){
30551c0b2f7Stbbdev     test_forwarding();
30651c0b2f7Stbbdev }
30751c0b2f7Stbbdev 
30851c0b2f7Stbbdev //! Test body copying and copy_body logic
30951c0b2f7Stbbdev //! \brief \ref interface
31051c0b2f7Stbbdev TEST_CASE("multifunction_node constructors"){
31151c0b2f7Stbbdev     test_copies();
31251c0b2f7Stbbdev }
31351c0b2f7Stbbdev 
31451c0b2f7Stbbdev //! Test calling function body
31551c0b2f7Stbbdev //! \brief \ref interface \ref requirement
31651c0b2f7Stbbdev TEST_CASE("multifunction_node body") {
31751c0b2f7Stbbdev     test_multifunc_body();
31851c0b2f7Stbbdev }
31951c0b2f7Stbbdev 
32051c0b2f7Stbbdev //! Test inheritance relations
32151c0b2f7Stbbdev //! \brief \ref interface
32251c0b2f7Stbbdev TEST_CASE("multifunction_node superclasses"){
32351c0b2f7Stbbdev     test_inheritance<int, std::tuple<int>>();
32451c0b2f7Stbbdev     test_inheritance<void*, std::tuple<float>>();
32551c0b2f7Stbbdev }
326