151c0b2f7Stbbdev /* 251c0b2f7Stbbdev Copyright (c) 2020 Intel Corporation 351c0b2f7Stbbdev 451c0b2f7Stbbdev Licensed under the Apache License, Version 2.0 (the "License"); 551c0b2f7Stbbdev you may not use this file except in compliance with the License. 651c0b2f7Stbbdev You may obtain a copy of the License at 751c0b2f7Stbbdev 851c0b2f7Stbbdev http://www.apache.org/licenses/LICENSE-2.0 951c0b2f7Stbbdev 1051c0b2f7Stbbdev Unless required by applicable law or agreed to in writing, software 1151c0b2f7Stbbdev distributed under the License is distributed on an "AS IS" BASIS, 1251c0b2f7Stbbdev WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1351c0b2f7Stbbdev See the License for the specific language governing permissions and 1451c0b2f7Stbbdev limitations under the License. 1551c0b2f7Stbbdev */ 1651c0b2f7Stbbdev 1751c0b2f7Stbbdev #include "common/test.h" 1851c0b2f7Stbbdev 1951c0b2f7Stbbdev #include "common/utils.h" 2051c0b2f7Stbbdev #include "common/graph_utils.h" 2151c0b2f7Stbbdev 22*49e08aacStbbdev #include "oneapi/tbb/flow_graph.h" 23*49e08aacStbbdev #include "oneapi/tbb/task_arena.h" 24*49e08aacStbbdev #include "oneapi/tbb/global_control.h" 2551c0b2f7Stbbdev 2651c0b2f7Stbbdev #include "conformance_flowgraph.h" 2751c0b2f7Stbbdev 2851c0b2f7Stbbdev //! \file conformance_function_node.cpp 2951c0b2f7Stbbdev //! \brief Test for [flow_graph.function_node] specification 3051c0b2f7Stbbdev 3151c0b2f7Stbbdev /* 3251c0b2f7Stbbdev TODO: implement missing conformance tests for function_node: 3351c0b2f7Stbbdev - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body> function_node( 3451c0b2f7Stbbdev graph &g, size_t concurrency, Body body, Policy(), node_priority_t, priority = no_priority )' 3551c0b2f7Stbbdev - [ ] Explicit test for copy constructor of the node. 3651c0b2f7Stbbdev - [ ] Rename test_broadcast to test_forwarding and check that the value passed is the actual one 3751c0b2f7Stbbdev received. 3851c0b2f7Stbbdev - [ ] Concurrency testing of the node: make a loop over possible concurrency levels. It is 39*49e08aacStbbdev important to test at least on five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism' 40*49e08aacStbbdev obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if `max allowed 4151c0b2f7Stbbdev parallelism' is > 2, use something in the middle of the [1, max_allowed_parallelism] 4251c0b2f7Stbbdev interval. Use `utils::ExactConcurrencyLevel' entity (extending it if necessary). 4351c0b2f7Stbbdev - [ ] make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads; 4451c0b2f7Stbbdev add check that `try_put()' returns `false' 4551c0b2f7Stbbdev - [ ] The copy constructor and copy assignment are called for the node's input and output types. 4651c0b2f7Stbbdev - [ ] The `copy_body' function copies altered body (e.g. after successful `try_put()' call). 4751c0b2f7Stbbdev - [ ] Extend CTAD test to check all node's constructors. 4851c0b2f7Stbbdev */ 4951c0b2f7Stbbdev 5051c0b2f7Stbbdev std::atomic<size_t> my_concurrency; 5151c0b2f7Stbbdev std::atomic<size_t> my_max_concurrency; 5251c0b2f7Stbbdev 5351c0b2f7Stbbdev template< typename OutputType > 5451c0b2f7Stbbdev struct concurrency_functor { 5551c0b2f7Stbbdev OutputType operator()( int argument ) { 5651c0b2f7Stbbdev ++my_concurrency; 5751c0b2f7Stbbdev 5851c0b2f7Stbbdev size_t old_value = my_max_concurrency; 5951c0b2f7Stbbdev while(my_max_concurrency < my_concurrency && 6051c0b2f7Stbbdev !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency)) 6151c0b2f7Stbbdev ; 6251c0b2f7Stbbdev 6351c0b2f7Stbbdev size_t ms = 1000; 6451c0b2f7Stbbdev std::chrono::milliseconds sleep_time( ms ); 6551c0b2f7Stbbdev std::this_thread::sleep_for( sleep_time ); 6651c0b2f7Stbbdev 6751c0b2f7Stbbdev --my_concurrency; 6851c0b2f7Stbbdev return argument; 6951c0b2f7Stbbdev } 7051c0b2f7Stbbdev 7151c0b2f7Stbbdev }; 7251c0b2f7Stbbdev 7351c0b2f7Stbbdev void test_func_body(){ 74*49e08aacStbbdev oneapi::tbb::flow::graph g; 7551c0b2f7Stbbdev inc_functor<int> fun; 7651c0b2f7Stbbdev fun.execute_count = 0; 7751c0b2f7Stbbdev 78*49e08aacStbbdev oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun); 7951c0b2f7Stbbdev 8051c0b2f7Stbbdev const size_t n = 10; 8151c0b2f7Stbbdev for(size_t i = 0; i < n; ++i) { 8251c0b2f7Stbbdev CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true"); 8351c0b2f7Stbbdev } 8451c0b2f7Stbbdev g.wait_for_all(); 8551c0b2f7Stbbdev 8651c0b2f7Stbbdev CHECK_MESSAGE( (fun.execute_count == n), "Body of the node needs to be executed N times"); 8751c0b2f7Stbbdev } 8851c0b2f7Stbbdev 8951c0b2f7Stbbdev void test_priority(){ 9051c0b2f7Stbbdev size_t concurrency_limit = 1; 91*49e08aacStbbdev oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit); 9251c0b2f7Stbbdev 93*49e08aacStbbdev oneapi::tbb::flow::graph g; 9451c0b2f7Stbbdev 9551c0b2f7Stbbdev first_functor<int>::first_id.store(-1); 9651c0b2f7Stbbdev first_functor<int> low_functor(1); 9751c0b2f7Stbbdev first_functor<int> high_functor(2); 9851c0b2f7Stbbdev 99*49e08aacStbbdev oneapi::tbb::flow::continue_node<int> source(g, [&](oneapi::tbb::flow::continue_msg){return 1;} ); 10051c0b2f7Stbbdev 101*49e08aacStbbdev oneapi::tbb::flow::function_node<int, int> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1)); 102*49e08aacStbbdev oneapi::tbb::flow::function_node<int, int> low(g, oneapi::tbb::flow::unlimited, low_functor); 10351c0b2f7Stbbdev 10451c0b2f7Stbbdev make_edge(source, low); 10551c0b2f7Stbbdev make_edge(source, high); 10651c0b2f7Stbbdev 107*49e08aacStbbdev source.try_put(oneapi::tbb::flow::continue_msg()); 10851c0b2f7Stbbdev g.wait_for_all(); 10951c0b2f7Stbbdev 11051c0b2f7Stbbdev CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first"); 11151c0b2f7Stbbdev } 11251c0b2f7Stbbdev 11351c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 11451c0b2f7Stbbdev void test_deduction_guides(){ 115*49e08aacStbbdev using namespace oneapi::tbb::flow; 11651c0b2f7Stbbdev graph g; 11751c0b2f7Stbbdev 11851c0b2f7Stbbdev auto body = [](const int&)->int { return 1; }; 11951c0b2f7Stbbdev function_node f1(g, unlimited, body); 12051c0b2f7Stbbdev CHECK_MESSAGE((std::is_same_v<decltype(f1), function_node<int, int>>), "Function node type must be deducible from its body"); 12151c0b2f7Stbbdev } 12251c0b2f7Stbbdev #endif 12351c0b2f7Stbbdev 12451c0b2f7Stbbdev void test_broadcast(){ 125*49e08aacStbbdev oneapi::tbb::flow::graph g; 12651c0b2f7Stbbdev passthru_body fun; 12751c0b2f7Stbbdev 128*49e08aacStbbdev oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun); 12951c0b2f7Stbbdev test_push_receiver<int> node2(g); 13051c0b2f7Stbbdev test_push_receiver<int> node3(g); 13151c0b2f7Stbbdev 132*49e08aacStbbdev oneapi::tbb::flow::make_edge(node1, node2); 133*49e08aacStbbdev oneapi::tbb::flow::make_edge(node1, node3); 13451c0b2f7Stbbdev 13551c0b2f7Stbbdev node1.try_put(1); 13651c0b2f7Stbbdev g.wait_for_all(); 13751c0b2f7Stbbdev 13851c0b2f7Stbbdev CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message."); 13951c0b2f7Stbbdev CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message."); 14051c0b2f7Stbbdev } 14151c0b2f7Stbbdev 14251c0b2f7Stbbdev template<typename Policy> 14351c0b2f7Stbbdev void test_buffering(){ 144*49e08aacStbbdev oneapi::tbb::flow::graph g; 14551c0b2f7Stbbdev passthru_body fun; 14651c0b2f7Stbbdev 147*49e08aacStbbdev oneapi::tbb::flow::function_node<int, int, Policy> node(g, oneapi::tbb::flow::unlimited, fun); 148*49e08aacStbbdev oneapi::tbb::flow::limiter_node<int> rejecter(g, 0); 14951c0b2f7Stbbdev 150*49e08aacStbbdev oneapi::tbb::flow::make_edge(node, rejecter); 15151c0b2f7Stbbdev node.try_put(1); 15251c0b2f7Stbbdev 15351c0b2f7Stbbdev int tmp = -1; 15451c0b2f7Stbbdev CHECK_MESSAGE( (node.try_get(tmp) == false), "try_get after rejection should not succeed"); 15551c0b2f7Stbbdev CHECK_MESSAGE( (tmp == -1), "try_get after rejection should not alter passed value"); 15651c0b2f7Stbbdev g.wait_for_all(); 15751c0b2f7Stbbdev } 15851c0b2f7Stbbdev 15951c0b2f7Stbbdev void test_node_concurrency(){ 16051c0b2f7Stbbdev my_concurrency = 0; 16151c0b2f7Stbbdev my_max_concurrency = 0; 16251c0b2f7Stbbdev 163*49e08aacStbbdev oneapi::tbb::flow::graph g; 16451c0b2f7Stbbdev concurrency_functor<int> counter; 165*49e08aacStbbdev oneapi::tbb::flow::function_node <int, int> fnode(g, oneapi::tbb::flow::serial, counter); 16651c0b2f7Stbbdev 16751c0b2f7Stbbdev test_push_receiver<int> sink(g); 16851c0b2f7Stbbdev 16951c0b2f7Stbbdev make_edge(fnode, sink); 17051c0b2f7Stbbdev 17151c0b2f7Stbbdev for(int i = 0; i < 10; ++i){ 17251c0b2f7Stbbdev fnode.try_put(i); 17351c0b2f7Stbbdev } 17451c0b2f7Stbbdev 17551c0b2f7Stbbdev g.wait_for_all(); 17651c0b2f7Stbbdev 17751c0b2f7Stbbdev CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism is not expected"); 17851c0b2f7Stbbdev } 17951c0b2f7Stbbdev 18051c0b2f7Stbbdev template<typename I, typename O> 18151c0b2f7Stbbdev void test_inheritance(){ 182*49e08aacStbbdev using namespace oneapi::tbb::flow; 18351c0b2f7Stbbdev 18451c0b2f7Stbbdev CHECK_MESSAGE( (std::is_base_of<graph_node, function_node<I, O>>::value), "function_node should be derived from graph_node"); 18551c0b2f7Stbbdev CHECK_MESSAGE( (std::is_base_of<receiver<I>, function_node<I, O>>::value), "function_node should be derived from receiver<Input>"); 18651c0b2f7Stbbdev CHECK_MESSAGE( (std::is_base_of<sender<O>, function_node<I, O>>::value), "function_node should be derived from sender<Output>"); 18751c0b2f7Stbbdev } 18851c0b2f7Stbbdev 18951c0b2f7Stbbdev void test_policy_ctors(){ 190*49e08aacStbbdev using namespace oneapi::tbb::flow; 19151c0b2f7Stbbdev graph g; 19251c0b2f7Stbbdev 193*49e08aacStbbdev function_node<int, int, lightweight> lw_node(g, oneapi::tbb::flow::serial, 19451c0b2f7Stbbdev [](int v) { return v;}); 195*49e08aacStbbdev function_node<int, int, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial, 19651c0b2f7Stbbdev [](int v) { return v;}); 197*49e08aacStbbdev function_node<int, int, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial, 19851c0b2f7Stbbdev [](int v) { return v;}); 19951c0b2f7Stbbdev 20051c0b2f7Stbbdev } 20151c0b2f7Stbbdev 20251c0b2f7Stbbdev class stateful_functor{ 20351c0b2f7Stbbdev public: 20451c0b2f7Stbbdev int stored; 20551c0b2f7Stbbdev stateful_functor(): stored(-1){} 20651c0b2f7Stbbdev int operator()(int value){ stored = 1; return value;} 20751c0b2f7Stbbdev }; 20851c0b2f7Stbbdev 20951c0b2f7Stbbdev void test_ctors(){ 210*49e08aacStbbdev using namespace oneapi::tbb::flow; 21151c0b2f7Stbbdev graph g; 21251c0b2f7Stbbdev 21351c0b2f7Stbbdev function_node<int, int> fn(g, unlimited, stateful_functor()); 21451c0b2f7Stbbdev fn.try_put(0); 21551c0b2f7Stbbdev g.wait_for_all(); 21651c0b2f7Stbbdev 21751c0b2f7Stbbdev stateful_functor b1 = copy_body<stateful_functor, function_node<int, int>>(fn); 21851c0b2f7Stbbdev CHECK_MESSAGE( (b1.stored == 1), "First node should update"); 21951c0b2f7Stbbdev 22051c0b2f7Stbbdev function_node<int, int> fn2(fn); 22151c0b2f7Stbbdev stateful_functor b2 = copy_body<stateful_functor, function_node<int, int>>(fn2); 22251c0b2f7Stbbdev CHECK_MESSAGE( (b2.stored == -1), "Copied node should not update"); 22351c0b2f7Stbbdev } 22451c0b2f7Stbbdev 22551c0b2f7Stbbdev template<typename I, typename O> 22651c0b2f7Stbbdev struct CopyCounterBody{ 22751c0b2f7Stbbdev size_t copy_count; 22851c0b2f7Stbbdev 22951c0b2f7Stbbdev CopyCounterBody(): 23051c0b2f7Stbbdev copy_count(0) {} 23151c0b2f7Stbbdev 23251c0b2f7Stbbdev CopyCounterBody(const CopyCounterBody<I, O>& other): 23351c0b2f7Stbbdev copy_count(other.copy_count + 1) {} 23451c0b2f7Stbbdev 23551c0b2f7Stbbdev CopyCounterBody& operator=(const CopyCounterBody<I, O>& other) 23651c0b2f7Stbbdev { copy_count = other.copy_count + 1; return *this;} 23751c0b2f7Stbbdev 23851c0b2f7Stbbdev O operator()(I in){ 23951c0b2f7Stbbdev return in; 24051c0b2f7Stbbdev } 24151c0b2f7Stbbdev }; 24251c0b2f7Stbbdev 24351c0b2f7Stbbdev void test_copies(){ 244*49e08aacStbbdev using namespace oneapi::tbb::flow; 24551c0b2f7Stbbdev 24651c0b2f7Stbbdev CopyCounterBody<int, int> b; 24751c0b2f7Stbbdev 24851c0b2f7Stbbdev graph g; 24951c0b2f7Stbbdev function_node<int, int> fn(g, unlimited, b); 25051c0b2f7Stbbdev 25151c0b2f7Stbbdev CopyCounterBody<int, int> b2 = copy_body<CopyCounterBody<int, int>, function_node<int, int>>(fn); 25251c0b2f7Stbbdev 25351c0b2f7Stbbdev CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies"); 25451c0b2f7Stbbdev } 25551c0b2f7Stbbdev 25651c0b2f7Stbbdev void test_rejecting(){ 257*49e08aacStbbdev oneapi::tbb::flow::graph g; 258*49e08aacStbbdev oneapi::tbb::flow::function_node <int, int, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial, 25951c0b2f7Stbbdev [&](int v){ 26051c0b2f7Stbbdev size_t ms = 50; 26151c0b2f7Stbbdev std::chrono::milliseconds sleep_time( ms ); 26251c0b2f7Stbbdev std::this_thread::sleep_for( sleep_time ); 26351c0b2f7Stbbdev return v; 26451c0b2f7Stbbdev }); 26551c0b2f7Stbbdev 26651c0b2f7Stbbdev test_push_receiver<int> sink(g); 26751c0b2f7Stbbdev 26851c0b2f7Stbbdev make_edge(fnode, sink); 26951c0b2f7Stbbdev 27051c0b2f7Stbbdev for(int i = 0; i < 10; ++i){ 27151c0b2f7Stbbdev fnode.try_put(i); 27251c0b2f7Stbbdev } 27351c0b2f7Stbbdev 27451c0b2f7Stbbdev g.wait_for_all(); 27551c0b2f7Stbbdev CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed"); 27651c0b2f7Stbbdev } 27751c0b2f7Stbbdev 27851c0b2f7Stbbdev //! Test function_node with rejecting policy 27951c0b2f7Stbbdev //! \brief \ref interface 28051c0b2f7Stbbdev TEST_CASE("function_node with rejecting policy"){ 28151c0b2f7Stbbdev test_rejecting(); 28251c0b2f7Stbbdev } 28351c0b2f7Stbbdev 28451c0b2f7Stbbdev //! Test body copying and copy_body logic 28551c0b2f7Stbbdev //! \brief \ref interface 28651c0b2f7Stbbdev TEST_CASE("function_node and body copying"){ 28751c0b2f7Stbbdev test_copies(); 28851c0b2f7Stbbdev } 28951c0b2f7Stbbdev 29051c0b2f7Stbbdev //! Test constructors 29151c0b2f7Stbbdev //! \brief \ref interface 29251c0b2f7Stbbdev TEST_CASE("function_node constructors"){ 29351c0b2f7Stbbdev test_policy_ctors(); 29451c0b2f7Stbbdev } 29551c0b2f7Stbbdev 29651c0b2f7Stbbdev //! Test inheritance relations 29751c0b2f7Stbbdev //! \brief \ref interface 29851c0b2f7Stbbdev TEST_CASE("function_node superclasses"){ 29951c0b2f7Stbbdev test_inheritance<int, int>(); 30051c0b2f7Stbbdev test_inheritance<void*, float>(); 30151c0b2f7Stbbdev } 30251c0b2f7Stbbdev 30351c0b2f7Stbbdev //! Test function_node buffering 30451c0b2f7Stbbdev //! \brief \ref requirement 30551c0b2f7Stbbdev TEST_CASE("function_node buffering"){ 306*49e08aacStbbdev test_buffering<oneapi::tbb::flow::rejecting>(); 307*49e08aacStbbdev test_buffering<oneapi::tbb::flow::queueing>(); 30851c0b2f7Stbbdev } 30951c0b2f7Stbbdev 31051c0b2f7Stbbdev //! Test function_node broadcasting 31151c0b2f7Stbbdev //! \brief \ref requirement 31251c0b2f7Stbbdev TEST_CASE("function_node broadcast"){ 31351c0b2f7Stbbdev test_broadcast(); 31451c0b2f7Stbbdev } 31551c0b2f7Stbbdev 31651c0b2f7Stbbdev //! Test deduction guides 31751c0b2f7Stbbdev //! \brief \ref interface \ref requirement 31851c0b2f7Stbbdev TEST_CASE("Deduction guides"){ 31951c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 32051c0b2f7Stbbdev test_deduction_guides(); 32151c0b2f7Stbbdev #endif 32251c0b2f7Stbbdev } 32351c0b2f7Stbbdev 32451c0b2f7Stbbdev //! Test priorities work in single-threaded configuration 32551c0b2f7Stbbdev //! \brief \ref requirement 32651c0b2f7Stbbdev TEST_CASE("function_node priority support"){ 32751c0b2f7Stbbdev test_priority(); 32851c0b2f7Stbbdev } 32951c0b2f7Stbbdev 33051c0b2f7Stbbdev //! Test that measured concurrency respects set limits 33151c0b2f7Stbbdev //! \brief \ref requirement 33251c0b2f7Stbbdev TEST_CASE("concurrency follows set limits"){ 33351c0b2f7Stbbdev test_node_concurrency(); 33451c0b2f7Stbbdev } 33551c0b2f7Stbbdev 33651c0b2f7Stbbdev //! Test calling function body 33751c0b2f7Stbbdev //! \brief \ref interface \ref requirement 33851c0b2f7Stbbdev TEST_CASE("Test function_node body") { 33951c0b2f7Stbbdev test_func_body(); 34051c0b2f7Stbbdev } 341