151c0b2f7Stbbdev /* 2*b15aabb3Stbbdev Copyright (c) 2020-2021 Intel Corporation 351c0b2f7Stbbdev 451c0b2f7Stbbdev Licensed under the Apache License, Version 2.0 (the "License"); 551c0b2f7Stbbdev you may not use this file except in compliance with the License. 651c0b2f7Stbbdev You may obtain a copy of the License at 751c0b2f7Stbbdev 851c0b2f7Stbbdev http://www.apache.org/licenses/LICENSE-2.0 951c0b2f7Stbbdev 1051c0b2f7Stbbdev Unless required by applicable law or agreed to in writing, software 1151c0b2f7Stbbdev distributed under the License is distributed on an "AS IS" BASIS, 1251c0b2f7Stbbdev WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1351c0b2f7Stbbdev See the License for the specific language governing permissions and 1451c0b2f7Stbbdev limitations under the License. 1551c0b2f7Stbbdev */ 1651c0b2f7Stbbdev 17*b15aabb3Stbbdev #if __INTEL_COMPILER && _MSC_VER 18*b15aabb3Stbbdev #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated 19*b15aabb3Stbbdev #endif 20*b15aabb3Stbbdev 2151c0b2f7Stbbdev #include "common/test.h" 2251c0b2f7Stbbdev 2351c0b2f7Stbbdev #include "common/utils.h" 2451c0b2f7Stbbdev #include "common/graph_utils.h" 2551c0b2f7Stbbdev 2649e08aacStbbdev #include "oneapi/tbb/flow_graph.h" 2749e08aacStbbdev #include "oneapi/tbb/task_arena.h" 2849e08aacStbbdev #include "oneapi/tbb/global_control.h" 2951c0b2f7Stbbdev 3051c0b2f7Stbbdev #include "conformance_flowgraph.h" 3151c0b2f7Stbbdev 3251c0b2f7Stbbdev //! \file conformance_function_node.cpp 3351c0b2f7Stbbdev //! 
//! \brief Test for [flow_graph.function_node] specification

/*
TODO: implement missing conformance tests for function_node:
  - [ ] Constructor with an explicitly passed Policy parameter: `template<typename Body> function_node(
    graph &g, size_t concurrency, Body body, Policy(), node_priority_t priority = no_priority )'
  - [ ] Explicit test for the copy constructor of the node.
  - [ ] Rename test_broadcast to test_forwarding and check that the value passed is the one actually
    received.
  - [ ] Concurrency testing of the node: loop over the possible concurrency levels. It is
    important to test at least five values: 1, `oneapi::tbb::flow::serial', `max_allowed_parallelism'
    obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if
    `max_allowed_parallelism' is > 2, something in the middle of the [1, max_allowed_parallelism]
    interval. Use the `utils::ExactConcurrencyLevel' entity (extending it if necessary).
  - [ ] Make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads;
    add a check that `try_put()' returns `false'.
  - [ ] Check that the copy constructor and copy assignment are called for the node's input and
    output types.
  - [ ] Check that the `copy_body' function copies the altered body (e.g. after a successful
    `try_put()' call).
  - [ ] Extend the CTAD test to check all of the node's constructors.
5251c0b2f7Stbbdev */ 5351c0b2f7Stbbdev 5451c0b2f7Stbbdev std::atomic<size_t> my_concurrency; 5551c0b2f7Stbbdev std::atomic<size_t> my_max_concurrency; 5651c0b2f7Stbbdev 5751c0b2f7Stbbdev template< typename OutputType > 5851c0b2f7Stbbdev struct concurrency_functor { 5951c0b2f7Stbbdev OutputType operator()( int argument ) { 6051c0b2f7Stbbdev ++my_concurrency; 6151c0b2f7Stbbdev 6251c0b2f7Stbbdev size_t old_value = my_max_concurrency; 6351c0b2f7Stbbdev while(my_max_concurrency < my_concurrency && 6451c0b2f7Stbbdev !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency)) 6551c0b2f7Stbbdev ; 6651c0b2f7Stbbdev 6751c0b2f7Stbbdev size_t ms = 1000; 6851c0b2f7Stbbdev std::chrono::milliseconds sleep_time( ms ); 6951c0b2f7Stbbdev std::this_thread::sleep_for( sleep_time ); 7051c0b2f7Stbbdev 7151c0b2f7Stbbdev --my_concurrency; 7251c0b2f7Stbbdev return argument; 7351c0b2f7Stbbdev } 7451c0b2f7Stbbdev 7551c0b2f7Stbbdev }; 7651c0b2f7Stbbdev 7751c0b2f7Stbbdev void test_func_body(){ 7849e08aacStbbdev oneapi::tbb::flow::graph g; 7951c0b2f7Stbbdev inc_functor<int> fun; 8051c0b2f7Stbbdev fun.execute_count = 0; 8151c0b2f7Stbbdev 8249e08aacStbbdev oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun); 8351c0b2f7Stbbdev 8451c0b2f7Stbbdev const size_t n = 10; 8551c0b2f7Stbbdev for(size_t i = 0; i < n; ++i) { 8651c0b2f7Stbbdev CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true"); 8751c0b2f7Stbbdev } 8851c0b2f7Stbbdev g.wait_for_all(); 8951c0b2f7Stbbdev 9051c0b2f7Stbbdev CHECK_MESSAGE( (fun.execute_count == n), "Body of the node needs to be executed N times"); 9151c0b2f7Stbbdev } 9251c0b2f7Stbbdev 9351c0b2f7Stbbdev void test_priority(){ 9451c0b2f7Stbbdev size_t concurrency_limit = 1; 9549e08aacStbbdev oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit); 9651c0b2f7Stbbdev 9749e08aacStbbdev oneapi::tbb::flow::graph g; 9851c0b2f7Stbbdev 9951c0b2f7Stbbdev 
first_functor<int>::first_id.store(-1); 10051c0b2f7Stbbdev first_functor<int> low_functor(1); 10151c0b2f7Stbbdev first_functor<int> high_functor(2); 10251c0b2f7Stbbdev 10349e08aacStbbdev oneapi::tbb::flow::continue_node<int> source(g, [&](oneapi::tbb::flow::continue_msg){return 1;} ); 10451c0b2f7Stbbdev 10549e08aacStbbdev oneapi::tbb::flow::function_node<int, int> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1)); 10649e08aacStbbdev oneapi::tbb::flow::function_node<int, int> low(g, oneapi::tbb::flow::unlimited, low_functor); 10751c0b2f7Stbbdev 10851c0b2f7Stbbdev make_edge(source, low); 10951c0b2f7Stbbdev make_edge(source, high); 11051c0b2f7Stbbdev 11149e08aacStbbdev source.try_put(oneapi::tbb::flow::continue_msg()); 11251c0b2f7Stbbdev g.wait_for_all(); 11351c0b2f7Stbbdev 11451c0b2f7Stbbdev CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first"); 11551c0b2f7Stbbdev } 11651c0b2f7Stbbdev 11751c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 11851c0b2f7Stbbdev void test_deduction_guides(){ 11949e08aacStbbdev using namespace oneapi::tbb::flow; 12051c0b2f7Stbbdev graph g; 12151c0b2f7Stbbdev 12251c0b2f7Stbbdev auto body = [](const int&)->int { return 1; }; 12351c0b2f7Stbbdev function_node f1(g, unlimited, body); 12451c0b2f7Stbbdev CHECK_MESSAGE((std::is_same_v<decltype(f1), function_node<int, int>>), "Function node type must be deducible from its body"); 12551c0b2f7Stbbdev } 12651c0b2f7Stbbdev #endif 12751c0b2f7Stbbdev 12851c0b2f7Stbbdev void test_broadcast(){ 12949e08aacStbbdev oneapi::tbb::flow::graph g; 13051c0b2f7Stbbdev passthru_body fun; 13151c0b2f7Stbbdev 13249e08aacStbbdev oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun); 13351c0b2f7Stbbdev test_push_receiver<int> node2(g); 13451c0b2f7Stbbdev test_push_receiver<int> node3(g); 13551c0b2f7Stbbdev 13649e08aacStbbdev oneapi::tbb::flow::make_edge(node1, node2); 13749e08aacStbbdev 
oneapi::tbb::flow::make_edge(node1, node3); 13851c0b2f7Stbbdev 13951c0b2f7Stbbdev node1.try_put(1); 14051c0b2f7Stbbdev g.wait_for_all(); 14151c0b2f7Stbbdev 14251c0b2f7Stbbdev CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message."); 14351c0b2f7Stbbdev CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message."); 14451c0b2f7Stbbdev } 14551c0b2f7Stbbdev 14651c0b2f7Stbbdev template<typename Policy> 14751c0b2f7Stbbdev void test_buffering(){ 14849e08aacStbbdev oneapi::tbb::flow::graph g; 14951c0b2f7Stbbdev passthru_body fun; 15051c0b2f7Stbbdev 15149e08aacStbbdev oneapi::tbb::flow::function_node<int, int, Policy> node(g, oneapi::tbb::flow::unlimited, fun); 15249e08aacStbbdev oneapi::tbb::flow::limiter_node<int> rejecter(g, 0); 15351c0b2f7Stbbdev 15449e08aacStbbdev oneapi::tbb::flow::make_edge(node, rejecter); 15551c0b2f7Stbbdev node.try_put(1); 15651c0b2f7Stbbdev 15751c0b2f7Stbbdev int tmp = -1; 15851c0b2f7Stbbdev CHECK_MESSAGE( (node.try_get(tmp) == false), "try_get after rejection should not succeed"); 15951c0b2f7Stbbdev CHECK_MESSAGE( (tmp == -1), "try_get after rejection should not alter passed value"); 16051c0b2f7Stbbdev g.wait_for_all(); 16151c0b2f7Stbbdev } 16251c0b2f7Stbbdev 16351c0b2f7Stbbdev void test_node_concurrency(){ 16451c0b2f7Stbbdev my_concurrency = 0; 16551c0b2f7Stbbdev my_max_concurrency = 0; 16651c0b2f7Stbbdev 16749e08aacStbbdev oneapi::tbb::flow::graph g; 16851c0b2f7Stbbdev concurrency_functor<int> counter; 16949e08aacStbbdev oneapi::tbb::flow::function_node <int, int> fnode(g, oneapi::tbb::flow::serial, counter); 17051c0b2f7Stbbdev 17151c0b2f7Stbbdev test_push_receiver<int> sink(g); 17251c0b2f7Stbbdev 17351c0b2f7Stbbdev make_edge(fnode, sink); 17451c0b2f7Stbbdev 17551c0b2f7Stbbdev for(int i = 0; i < 10; ++i){ 17651c0b2f7Stbbdev fnode.try_put(i); 17751c0b2f7Stbbdev } 17851c0b2f7Stbbdev 17951c0b2f7Stbbdev g.wait_for_all(); 18051c0b2f7Stbbdev 18151c0b2f7Stbbdev CHECK_MESSAGE( ( 
my_max_concurrency.load() == 1), "Measured parallelism is not expected"); 18251c0b2f7Stbbdev } 18351c0b2f7Stbbdev 18451c0b2f7Stbbdev template<typename I, typename O> 18551c0b2f7Stbbdev void test_inheritance(){ 18649e08aacStbbdev using namespace oneapi::tbb::flow; 18751c0b2f7Stbbdev 18851c0b2f7Stbbdev CHECK_MESSAGE( (std::is_base_of<graph_node, function_node<I, O>>::value), "function_node should be derived from graph_node"); 18951c0b2f7Stbbdev CHECK_MESSAGE( (std::is_base_of<receiver<I>, function_node<I, O>>::value), "function_node should be derived from receiver<Input>"); 19051c0b2f7Stbbdev CHECK_MESSAGE( (std::is_base_of<sender<O>, function_node<I, O>>::value), "function_node should be derived from sender<Output>"); 19151c0b2f7Stbbdev } 19251c0b2f7Stbbdev 19351c0b2f7Stbbdev void test_policy_ctors(){ 19449e08aacStbbdev using namespace oneapi::tbb::flow; 19551c0b2f7Stbbdev graph g; 19651c0b2f7Stbbdev 19749e08aacStbbdev function_node<int, int, lightweight> lw_node(g, oneapi::tbb::flow::serial, 19851c0b2f7Stbbdev [](int v) { return v;}); 19949e08aacStbbdev function_node<int, int, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial, 20051c0b2f7Stbbdev [](int v) { return v;}); 20149e08aacStbbdev function_node<int, int, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial, 20251c0b2f7Stbbdev [](int v) { return v;}); 20351c0b2f7Stbbdev 20451c0b2f7Stbbdev } 20551c0b2f7Stbbdev 20651c0b2f7Stbbdev class stateful_functor{ 20751c0b2f7Stbbdev public: 20851c0b2f7Stbbdev int stored; 20951c0b2f7Stbbdev stateful_functor(): stored(-1){} 21051c0b2f7Stbbdev int operator()(int value){ stored = 1; return value;} 21151c0b2f7Stbbdev }; 21251c0b2f7Stbbdev 21351c0b2f7Stbbdev void test_ctors(){ 21449e08aacStbbdev using namespace oneapi::tbb::flow; 21551c0b2f7Stbbdev graph g; 21651c0b2f7Stbbdev 21751c0b2f7Stbbdev function_node<int, int> fn(g, unlimited, stateful_functor()); 21851c0b2f7Stbbdev fn.try_put(0); 21951c0b2f7Stbbdev g.wait_for_all(); 22051c0b2f7Stbbdev 
22151c0b2f7Stbbdev stateful_functor b1 = copy_body<stateful_functor, function_node<int, int>>(fn); 22251c0b2f7Stbbdev CHECK_MESSAGE( (b1.stored == 1), "First node should update"); 22351c0b2f7Stbbdev 22451c0b2f7Stbbdev function_node<int, int> fn2(fn); 22551c0b2f7Stbbdev stateful_functor b2 = copy_body<stateful_functor, function_node<int, int>>(fn2); 22651c0b2f7Stbbdev CHECK_MESSAGE( (b2.stored == -1), "Copied node should not update"); 22751c0b2f7Stbbdev } 22851c0b2f7Stbbdev 22951c0b2f7Stbbdev template<typename I, typename O> 23051c0b2f7Stbbdev struct CopyCounterBody{ 23151c0b2f7Stbbdev size_t copy_count; 23251c0b2f7Stbbdev 23351c0b2f7Stbbdev CopyCounterBody(): 23451c0b2f7Stbbdev copy_count(0) {} 23551c0b2f7Stbbdev 23651c0b2f7Stbbdev CopyCounterBody(const CopyCounterBody<I, O>& other): 23751c0b2f7Stbbdev copy_count(other.copy_count + 1) {} 23851c0b2f7Stbbdev 23951c0b2f7Stbbdev CopyCounterBody& operator=(const CopyCounterBody<I, O>& other) 24051c0b2f7Stbbdev { copy_count = other.copy_count + 1; return *this;} 24151c0b2f7Stbbdev 24251c0b2f7Stbbdev O operator()(I in){ 24351c0b2f7Stbbdev return in; 24451c0b2f7Stbbdev } 24551c0b2f7Stbbdev }; 24651c0b2f7Stbbdev 24751c0b2f7Stbbdev void test_copies(){ 24849e08aacStbbdev using namespace oneapi::tbb::flow; 24951c0b2f7Stbbdev 25051c0b2f7Stbbdev CopyCounterBody<int, int> b; 25151c0b2f7Stbbdev 25251c0b2f7Stbbdev graph g; 25351c0b2f7Stbbdev function_node<int, int> fn(g, unlimited, b); 25451c0b2f7Stbbdev 25551c0b2f7Stbbdev CopyCounterBody<int, int> b2 = copy_body<CopyCounterBody<int, int>, function_node<int, int>>(fn); 25651c0b2f7Stbbdev 25751c0b2f7Stbbdev CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies"); 25851c0b2f7Stbbdev } 25951c0b2f7Stbbdev 26051c0b2f7Stbbdev void test_rejecting(){ 26149e08aacStbbdev oneapi::tbb::flow::graph g; 26249e08aacStbbdev oneapi::tbb::flow::function_node <int, int, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial, 26351c0b2f7Stbbdev 
[&](int v){ 26451c0b2f7Stbbdev size_t ms = 50; 26551c0b2f7Stbbdev std::chrono::milliseconds sleep_time( ms ); 26651c0b2f7Stbbdev std::this_thread::sleep_for( sleep_time ); 26751c0b2f7Stbbdev return v; 26851c0b2f7Stbbdev }); 26951c0b2f7Stbbdev 27051c0b2f7Stbbdev test_push_receiver<int> sink(g); 27151c0b2f7Stbbdev 27251c0b2f7Stbbdev make_edge(fnode, sink); 27351c0b2f7Stbbdev 27451c0b2f7Stbbdev for(int i = 0; i < 10; ++i){ 27551c0b2f7Stbbdev fnode.try_put(i); 27651c0b2f7Stbbdev } 27751c0b2f7Stbbdev 27851c0b2f7Stbbdev g.wait_for_all(); 27951c0b2f7Stbbdev CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed"); 28051c0b2f7Stbbdev } 28151c0b2f7Stbbdev 28251c0b2f7Stbbdev //! Test function_node with rejecting policy 28351c0b2f7Stbbdev //! \brief \ref interface 28451c0b2f7Stbbdev TEST_CASE("function_node with rejecting policy"){ 28551c0b2f7Stbbdev test_rejecting(); 28651c0b2f7Stbbdev } 28751c0b2f7Stbbdev 28851c0b2f7Stbbdev //! Test body copying and copy_body logic 28951c0b2f7Stbbdev //! \brief \ref interface 29051c0b2f7Stbbdev TEST_CASE("function_node and body copying"){ 29151c0b2f7Stbbdev test_copies(); 29251c0b2f7Stbbdev } 29351c0b2f7Stbbdev 29451c0b2f7Stbbdev //! Test constructors 29551c0b2f7Stbbdev //! \brief \ref interface 29651c0b2f7Stbbdev TEST_CASE("function_node constructors"){ 29751c0b2f7Stbbdev test_policy_ctors(); 29851c0b2f7Stbbdev } 29951c0b2f7Stbbdev 30051c0b2f7Stbbdev //! Test inheritance relations 30151c0b2f7Stbbdev //! \brief \ref interface 30251c0b2f7Stbbdev TEST_CASE("function_node superclasses"){ 30351c0b2f7Stbbdev test_inheritance<int, int>(); 30451c0b2f7Stbbdev test_inheritance<void*, float>(); 30551c0b2f7Stbbdev } 30651c0b2f7Stbbdev 30751c0b2f7Stbbdev //! Test function_node buffering 30851c0b2f7Stbbdev //! 
\brief \ref requirement 30951c0b2f7Stbbdev TEST_CASE("function_node buffering"){ 31049e08aacStbbdev test_buffering<oneapi::tbb::flow::rejecting>(); 31149e08aacStbbdev test_buffering<oneapi::tbb::flow::queueing>(); 31251c0b2f7Stbbdev } 31351c0b2f7Stbbdev 31451c0b2f7Stbbdev //! Test function_node broadcasting 31551c0b2f7Stbbdev //! \brief \ref requirement 31651c0b2f7Stbbdev TEST_CASE("function_node broadcast"){ 31751c0b2f7Stbbdev test_broadcast(); 31851c0b2f7Stbbdev } 31951c0b2f7Stbbdev 32051c0b2f7Stbbdev //! Test deduction guides 32151c0b2f7Stbbdev //! \brief \ref interface \ref requirement 32251c0b2f7Stbbdev TEST_CASE("Deduction guides"){ 32351c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 32451c0b2f7Stbbdev test_deduction_guides(); 32551c0b2f7Stbbdev #endif 32651c0b2f7Stbbdev } 32751c0b2f7Stbbdev 32851c0b2f7Stbbdev //! Test priorities work in single-threaded configuration 32951c0b2f7Stbbdev //! \brief \ref requirement 33051c0b2f7Stbbdev TEST_CASE("function_node priority support"){ 33151c0b2f7Stbbdev test_priority(); 33251c0b2f7Stbbdev } 33351c0b2f7Stbbdev 33451c0b2f7Stbbdev //! Test that measured concurrency respects set limits 33551c0b2f7Stbbdev //! \brief \ref requirement 33651c0b2f7Stbbdev TEST_CASE("concurrency follows set limits"){ 33751c0b2f7Stbbdev test_node_concurrency(); 33851c0b2f7Stbbdev } 33951c0b2f7Stbbdev 34051c0b2f7Stbbdev //! Test calling function body 34151c0b2f7Stbbdev //! \brief \ref interface \ref requirement 34251c0b2f7Stbbdev TEST_CASE("Test function_node body") { 34351c0b2f7Stbbdev test_func_body(); 34451c0b2f7Stbbdev } 345