151c0b2f7Stbbdev /* 251c0b2f7Stbbdev Copyright (c) 2020 Intel Corporation 351c0b2f7Stbbdev 451c0b2f7Stbbdev Licensed under the Apache License, Version 2.0 (the "License"); 551c0b2f7Stbbdev you may not use this file except in compliance with the License. 651c0b2f7Stbbdev You may obtain a copy of the License at 751c0b2f7Stbbdev 851c0b2f7Stbbdev http://www.apache.org/licenses/LICENSE-2.0 951c0b2f7Stbbdev 1051c0b2f7Stbbdev Unless required by applicable law or agreed to in writing, software 1151c0b2f7Stbbdev distributed under the License is distributed on an "AS IS" BASIS, 1251c0b2f7Stbbdev WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1351c0b2f7Stbbdev See the License for the specific language governing permissions and 1451c0b2f7Stbbdev limitations under the License. 1551c0b2f7Stbbdev */ 1651c0b2f7Stbbdev 1751c0b2f7Stbbdev #include "common/test.h" 1851c0b2f7Stbbdev 1951c0b2f7Stbbdev #include "common/utils.h" 2051c0b2f7Stbbdev #include "common/graph_utils.h" 2151c0b2f7Stbbdev 22*49e08aacStbbdev #include "oneapi/tbb/flow_graph.h" 23*49e08aacStbbdev #include "oneapi/tbb/task_arena.h" 24*49e08aacStbbdev #include "oneapi/tbb/global_control.h" 2551c0b2f7Stbbdev 2651c0b2f7Stbbdev #include "conformance_flowgraph.h" 2751c0b2f7Stbbdev 2851c0b2f7Stbbdev //! \file conformance_multifunction_node.cpp 2951c0b2f7Stbbdev //! \brief Test for [flow_graph.function_node] specification 3051c0b2f7Stbbdev 3151c0b2f7Stbbdev /* 3251c0b2f7Stbbdev TODO: implement missing conformance tests for multifunction_node: 3351c0b2f7Stbbdev - [ ] Implement test_forwarding that checks messages are broadcast to all the successors connected 3451c0b2f7Stbbdev to the output port the message is being sent to. And check that the value passed is the 3551c0b2f7Stbbdev actual one received. 3651c0b2f7Stbbdev - [ ] Explicit test for copy constructor of the node. 3751c0b2f7Stbbdev - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body> 3851c0b2f7Stbbdev multifunction_node( graph &g, size_t concurrency, Body body, Policy(), node_priority_t priority = no_priority )'. 3951c0b2f7Stbbdev - [ ] Concurrency testing of the node: make a loop over possible concurrency levels. It is 40*49e08aacStbbdev important to test at least on five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism' 41*49e08aacStbbdev obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if `max allowed 4251c0b2f7Stbbdev parallelism' is > 2, use something in the middle of the [1, max_allowed_parallelism] 4351c0b2f7Stbbdev interval. Use `utils::ExactConcurrencyLevel' entity (extending it if necessary). 4451c0b2f7Stbbdev - [ ] make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads; 4551c0b2f7Stbbdev add check that `try_put()' returns `false' 4651c0b2f7Stbbdev - [ ] The `copy_body' function copies altered body (e.g. after successful `try_put()' call). 4751c0b2f7Stbbdev - [ ] `output_ports_type' is defined and accessible by the user. 4851c0b2f7Stbbdev - [ ] Explicit test on `mfn::output_ports()' method. 4951c0b2f7Stbbdev - [ ] The copy constructor and copy assignment are called for the node's input and output types. 5051c0b2f7Stbbdev - [ ] Add CTAD test. 5151c0b2f7Stbbdev */ 5251c0b2f7Stbbdev 5351c0b2f7Stbbdev template< typename OutputType > 5451c0b2f7Stbbdev struct mf_functor { 5551c0b2f7Stbbdev 5651c0b2f7Stbbdev std::atomic<std::size_t>& local_execute_count; 5751c0b2f7Stbbdev 5851c0b2f7Stbbdev mf_functor(std::atomic<std::size_t>& execute_count ) : 5951c0b2f7Stbbdev local_execute_count (execute_count) 6051c0b2f7Stbbdev { } 6151c0b2f7Stbbdev 6251c0b2f7Stbbdev mf_functor( const mf_functor &f ) : local_execute_count(f.local_execute_count) { } 6351c0b2f7Stbbdev void operator=(const mf_functor &f) { local_execute_count = std::size_t(f.local_execute_count); } 6451c0b2f7Stbbdev 65*49e08aacStbbdev void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) { 6651c0b2f7Stbbdev ++local_execute_count; 6751c0b2f7Stbbdev std::get<0>(op).try_put(argument); 6851c0b2f7Stbbdev } 6951c0b2f7Stbbdev 7051c0b2f7Stbbdev }; 7151c0b2f7Stbbdev 7251c0b2f7Stbbdev template<typename I, typename O> 7351c0b2f7Stbbdev void test_inheritance(){ 74*49e08aacStbbdev using namespace oneapi::tbb::flow; 7551c0b2f7Stbbdev 7651c0b2f7Stbbdev CHECK_MESSAGE( (std::is_base_of<graph_node, multifunction_node<I, O>>::value), "multifunction_node should be derived from graph_node"); 7751c0b2f7Stbbdev CHECK_MESSAGE( (std::is_base_of<receiver<I>, multifunction_node<I, O>>::value), "multifunction_node should be derived from receiver<Input>"); 7851c0b2f7Stbbdev } 7951c0b2f7Stbbdev 8051c0b2f7Stbbdev void test_multifunc_body(){ 81*49e08aacStbbdev oneapi::tbb::flow::graph g; 8251c0b2f7Stbbdev std::atomic<size_t> local_count(0); 8351c0b2f7Stbbdev mf_functor<std::tuple<int>> fun(local_count); 8451c0b2f7Stbbdev 85*49e08aacStbbdev oneapi::tbb::flow::multifunction_node<int, std::tuple<int>, oneapi::tbb::flow::rejecting> node1(g, oneapi::tbb::flow::unlimited, fun); 8651c0b2f7Stbbdev 8751c0b2f7Stbbdev const size_t n = 10; 8851c0b2f7Stbbdev for(size_t i = 0; i < n; ++i) { 8951c0b2f7Stbbdev CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true"); 9051c0b2f7Stbbdev } 9151c0b2f7Stbbdev g.wait_for_all(); 9251c0b2f7Stbbdev 9351c0b2f7Stbbdev CHECK_MESSAGE( (local_count == n), "Body of the node needs to be executed N times"); 9451c0b2f7Stbbdev } 9551c0b2f7Stbbdev 9651c0b2f7Stbbdev template<typename I, typename O> 9751c0b2f7Stbbdev struct CopyCounterBody{ 9851c0b2f7Stbbdev size_t copy_count; 9951c0b2f7Stbbdev 10051c0b2f7Stbbdev CopyCounterBody(): 10151c0b2f7Stbbdev copy_count(0) {} 10251c0b2f7Stbbdev 10351c0b2f7Stbbdev CopyCounterBody(const CopyCounterBody<I, O>& other): 10451c0b2f7Stbbdev copy_count(other.copy_count + 1) {} 10551c0b2f7Stbbdev 10651c0b2f7Stbbdev CopyCounterBody& operator=(const CopyCounterBody<I, O>& other) 10751c0b2f7Stbbdev { copy_count = other.copy_count + 1; return *this;} 10851c0b2f7Stbbdev 109*49e08aacStbbdev void operator()( const I& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) { 11051c0b2f7Stbbdev std::get<0>(op).try_put(argument); 11151c0b2f7Stbbdev } 11251c0b2f7Stbbdev }; 11351c0b2f7Stbbdev 11451c0b2f7Stbbdev void test_copies(){ 115*49e08aacStbbdev using namespace oneapi::tbb::flow; 11651c0b2f7Stbbdev 11751c0b2f7Stbbdev CopyCounterBody<int, std::tuple<int>> b; 11851c0b2f7Stbbdev 11951c0b2f7Stbbdev graph g; 12051c0b2f7Stbbdev multifunction_node<int, std::tuple<int>> fn(g, unlimited, b); 12151c0b2f7Stbbdev 12251c0b2f7Stbbdev CopyCounterBody<int, std::tuple<int>> b2 = copy_body<CopyCounterBody<int, std::tuple<int>>, 12351c0b2f7Stbbdev multifunction_node<int, std::tuple<int>>>(fn); 12451c0b2f7Stbbdev 12551c0b2f7Stbbdev CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies"); 12651c0b2f7Stbbdev } 12751c0b2f7Stbbdev 12851c0b2f7Stbbdev template< typename OutputType > 12951c0b2f7Stbbdev struct id_functor { 130*49e08aacStbbdev void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) { 13151c0b2f7Stbbdev std::get<0>(op).try_put(argument); 13251c0b2f7Stbbdev } 13351c0b2f7Stbbdev }; 13451c0b2f7Stbbdev 13551c0b2f7Stbbdev void test_forwarding(){ 136*49e08aacStbbdev oneapi::tbb::flow::graph g; 13751c0b2f7Stbbdev id_functor<int> fun; 13851c0b2f7Stbbdev 139*49e08aacStbbdev oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> node1(g, oneapi::tbb::flow::unlimited, fun); 14051c0b2f7Stbbdev test_push_receiver<int> node2(g); 14151c0b2f7Stbbdev test_push_receiver<int> node3(g); 14251c0b2f7Stbbdev 143*49e08aacStbbdev oneapi::tbb::flow::make_edge(node1, node2); 144*49e08aacStbbdev oneapi::tbb::flow::make_edge(node1, node3); 14551c0b2f7Stbbdev 14651c0b2f7Stbbdev node1.try_put(1); 14751c0b2f7Stbbdev g.wait_for_all(); 14851c0b2f7Stbbdev 14951c0b2f7Stbbdev CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message."); 15051c0b2f7Stbbdev CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message."); 15151c0b2f7Stbbdev } 15251c0b2f7Stbbdev 15351c0b2f7Stbbdev void test_rejecting_buffering(){ 154*49e08aacStbbdev oneapi::tbb::flow::graph g; 15551c0b2f7Stbbdev id_functor<int> fun; 15651c0b2f7Stbbdev 157*49e08aacStbbdev oneapi::tbb::flow::multifunction_node<int, std::tuple<int>, oneapi::tbb::flow::rejecting> node(g, oneapi::tbb::flow::unlimited, fun); 158*49e08aacStbbdev oneapi::tbb::flow::limiter_node<int> rejecter(g, 0); 15951c0b2f7Stbbdev 160*49e08aacStbbdev oneapi::tbb::flow::make_edge(node, rejecter); 16151c0b2f7Stbbdev node.try_put(1); 16251c0b2f7Stbbdev 16351c0b2f7Stbbdev int tmp = -1; 16451c0b2f7Stbbdev CHECK_MESSAGE( (std::get<0>(node.output_ports()).try_get(tmp) == false), "try_get after rejection should not succeed"); 16551c0b2f7Stbbdev CHECK_MESSAGE( (tmp == -1), "try_get after rejection should alter passed value"); 16651c0b2f7Stbbdev g.wait_for_all(); 16751c0b2f7Stbbdev } 16851c0b2f7Stbbdev 16951c0b2f7Stbbdev void test_policy_ctors(){ 170*49e08aacStbbdev using namespace oneapi::tbb::flow; 17151c0b2f7Stbbdev graph g; 17251c0b2f7Stbbdev 17351c0b2f7Stbbdev id_functor<int> fun; 17451c0b2f7Stbbdev 175*49e08aacStbbdev multifunction_node<int, std::tuple<int>, lightweight> lw_node(g, oneapi::tbb::flow::serial, fun); 176*49e08aacStbbdev multifunction_node<int, std::tuple<int>, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial, fun); 177*49e08aacStbbdev multifunction_node<int, std::tuple<int>, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial, fun); 17851c0b2f7Stbbdev 17951c0b2f7Stbbdev } 18051c0b2f7Stbbdev 18151c0b2f7Stbbdev std::atomic<size_t> my_concurrency; 18251c0b2f7Stbbdev std::atomic<size_t> my_max_concurrency; 18351c0b2f7Stbbdev 18451c0b2f7Stbbdev struct concurrency_functor { 185*49e08aacStbbdev void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) { 18651c0b2f7Stbbdev ++my_concurrency; 18751c0b2f7Stbbdev 18851c0b2f7Stbbdev size_t old_value = my_max_concurrency; 18951c0b2f7Stbbdev while(my_max_concurrency < my_concurrency && 19051c0b2f7Stbbdev !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency)) 19151c0b2f7Stbbdev ; 19251c0b2f7Stbbdev 19351c0b2f7Stbbdev size_t ms = 1000; 19451c0b2f7Stbbdev std::chrono::milliseconds sleep_time( ms ); 19551c0b2f7Stbbdev std::this_thread::sleep_for( sleep_time ); 19651c0b2f7Stbbdev 19751c0b2f7Stbbdev --my_concurrency; 19851c0b2f7Stbbdev std::get<0>(op).try_put(argument); 19951c0b2f7Stbbdev } 20051c0b2f7Stbbdev 20151c0b2f7Stbbdev }; 20251c0b2f7Stbbdev 20351c0b2f7Stbbdev void test_node_concurrency(){ 20451c0b2f7Stbbdev my_concurrency = 0; 20551c0b2f7Stbbdev my_max_concurrency = 0; 20651c0b2f7Stbbdev 207*49e08aacStbbdev oneapi::tbb::flow::graph g; 20851c0b2f7Stbbdev 20951c0b2f7Stbbdev concurrency_functor counter; 210*49e08aacStbbdev oneapi::tbb::flow::multifunction_node <int, std::tuple<int>> fnode(g, oneapi::tbb::flow::serial, counter); 21151c0b2f7Stbbdev 21251c0b2f7Stbbdev test_push_receiver<int> sink(g); 21351c0b2f7Stbbdev 21451c0b2f7Stbbdev make_edge(std::get<0>(fnode.output_ports()), sink); 21551c0b2f7Stbbdev 21651c0b2f7Stbbdev for(int i = 0; i < 10; ++i){ 21751c0b2f7Stbbdev fnode.try_put(i); 21851c0b2f7Stbbdev } 21951c0b2f7Stbbdev 22051c0b2f7Stbbdev g.wait_for_all(); 22151c0b2f7Stbbdev CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism over limit"); 22251c0b2f7Stbbdev } 22351c0b2f7Stbbdev 22451c0b2f7Stbbdev 22551c0b2f7Stbbdev void test_priority(){ 22651c0b2f7Stbbdev size_t concurrency_limit = 1; 227*49e08aacStbbdev oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit); 22851c0b2f7Stbbdev 229*49e08aacStbbdev oneapi::tbb::flow::graph g; 23051c0b2f7Stbbdev 231*49e08aacStbbdev oneapi::tbb::flow::continue_node<int> source(g, 232*49e08aacStbbdev [](oneapi::tbb::flow::continue_msg){ return 1;}); 233*49e08aacStbbdev source.try_put(oneapi::tbb::flow::continue_msg()); 23451c0b2f7Stbbdev 23551c0b2f7Stbbdev first_functor<int>::first_id = -1; 23651c0b2f7Stbbdev first_functor<int> low_functor(1); 23751c0b2f7Stbbdev first_functor<int> high_functor(2); 23851c0b2f7Stbbdev 239*49e08aacStbbdev oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1)); 240*49e08aacStbbdev oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> low(g, oneapi::tbb::flow::unlimited, low_functor); 24151c0b2f7Stbbdev 24251c0b2f7Stbbdev make_edge(source, low); 24351c0b2f7Stbbdev make_edge(source, high); 24451c0b2f7Stbbdev 24551c0b2f7Stbbdev g.wait_for_all(); 24651c0b2f7Stbbdev 24751c0b2f7Stbbdev CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first"); 24851c0b2f7Stbbdev } 24951c0b2f7Stbbdev 25051c0b2f7Stbbdev void test_rejecting(){ 251*49e08aacStbbdev oneapi::tbb::flow::graph g; 252*49e08aacStbbdev oneapi::tbb::flow::multifunction_node <int, std::tuple<int>, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial, 253*49e08aacStbbdev [&](const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ){ 25451c0b2f7Stbbdev size_t ms = 50; 25551c0b2f7Stbbdev std::chrono::milliseconds sleep_time( ms ); 25651c0b2f7Stbbdev std::this_thread::sleep_for( sleep_time ); 25751c0b2f7Stbbdev std::get<0>(op).try_put(argument); 25851c0b2f7Stbbdev }); 25951c0b2f7Stbbdev 26051c0b2f7Stbbdev test_push_receiver<int> sink(g); 26151c0b2f7Stbbdev 26251c0b2f7Stbbdev make_edge(std::get<0>(fnode.output_ports()), sink); 26351c0b2f7Stbbdev 26451c0b2f7Stbbdev for(int i = 0; i < 10; ++i){ 26551c0b2f7Stbbdev fnode.try_put(i); 26651c0b2f7Stbbdev } 26751c0b2f7Stbbdev 26851c0b2f7Stbbdev g.wait_for_all(); 26951c0b2f7Stbbdev CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed"); 27051c0b2f7Stbbdev } 27151c0b2f7Stbbdev 27251c0b2f7Stbbdev //! Test multifunction_node with rejecting policy 27351c0b2f7Stbbdev //! \brief \ref interface 27451c0b2f7Stbbdev TEST_CASE("multifunction_node with rejecting policy"){ 27551c0b2f7Stbbdev test_rejecting(); 27651c0b2f7Stbbdev } 27751c0b2f7Stbbdev 27851c0b2f7Stbbdev //! Test priorities 27951c0b2f7Stbbdev //! \brief \ref interface 28051c0b2f7Stbbdev TEST_CASE("multifunction_node priority"){ 28151c0b2f7Stbbdev test_priority(); 28251c0b2f7Stbbdev } 28351c0b2f7Stbbdev 28451c0b2f7Stbbdev //! Test concurrency 28551c0b2f7Stbbdev //! \brief \ref interface 28651c0b2f7Stbbdev TEST_CASE("multifunction_node concurrency"){ 28751c0b2f7Stbbdev test_node_concurrency(); 28851c0b2f7Stbbdev } 28951c0b2f7Stbbdev 29051c0b2f7Stbbdev //! Test constructors 29151c0b2f7Stbbdev //! \brief \ref interface 29251c0b2f7Stbbdev TEST_CASE("multifunction_node constructors"){ 29351c0b2f7Stbbdev test_policy_ctors(); 29451c0b2f7Stbbdev } 29551c0b2f7Stbbdev 29651c0b2f7Stbbdev //! Test function_node buffering 29751c0b2f7Stbbdev //! \brief \ref requirement 29851c0b2f7Stbbdev TEST_CASE("multifunction_node buffering"){ 29951c0b2f7Stbbdev test_rejecting_buffering(); 30051c0b2f7Stbbdev } 30151c0b2f7Stbbdev 30251c0b2f7Stbbdev //! Test function_node broadcasting 30351c0b2f7Stbbdev //! \brief \ref requirement 30451c0b2f7Stbbdev TEST_CASE("multifunction_node broadcast"){ 30551c0b2f7Stbbdev test_forwarding(); 30651c0b2f7Stbbdev } 30751c0b2f7Stbbdev 30851c0b2f7Stbbdev //! Test body copying and copy_body logic 30951c0b2f7Stbbdev //! \brief \ref interface 31051c0b2f7Stbbdev TEST_CASE("multifunction_node constructors"){ 31151c0b2f7Stbbdev test_copies(); 31251c0b2f7Stbbdev } 31351c0b2f7Stbbdev 31451c0b2f7Stbbdev //! Test calling function body 31551c0b2f7Stbbdev //! \brief \ref interface \ref requirement 31651c0b2f7Stbbdev TEST_CASE("multifunction_node body") { 31751c0b2f7Stbbdev test_multifunc_body(); 31851c0b2f7Stbbdev } 31951c0b2f7Stbbdev 32051c0b2f7Stbbdev //! Test inheritance relations 32151c0b2f7Stbbdev //! \brief \ref interface 32251c0b2f7Stbbdev TEST_CASE("multifunction_node superclasses"){ 32351c0b2f7Stbbdev test_inheritance<int, std::tuple<int>>(); 32451c0b2f7Stbbdev test_inheritance<void*, std::tuple<float>>(); 32551c0b2f7Stbbdev } 326