151c0b2f7Stbbdev /*
251c0b2f7Stbbdev     Copyright (c) 2020 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
1751c0b2f7Stbbdev #include "common/test.h"
1851c0b2f7Stbbdev 
1951c0b2f7Stbbdev #include "common/utils.h"
2051c0b2f7Stbbdev #include "common/graph_utils.h"
2151c0b2f7Stbbdev 
22*49e08aacStbbdev #include "oneapi/tbb/flow_graph.h"
23*49e08aacStbbdev #include "oneapi/tbb/task_arena.h"
24*49e08aacStbbdev #include "oneapi/tbb/global_control.h"
2551c0b2f7Stbbdev 
2651c0b2f7Stbbdev #include "conformance_flowgraph.h"
2751c0b2f7Stbbdev 
2851c0b2f7Stbbdev //! \file conformance_function_node.cpp
2951c0b2f7Stbbdev //! \brief Test for [flow_graph.function_node] specification
3051c0b2f7Stbbdev 
3151c0b2f7Stbbdev /*
3251c0b2f7Stbbdev TODO: implement missing conformance tests for function_node:
3351c0b2f7Stbbdev   - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body> function_node(
3451c0b2f7Stbbdev     graph &g, size_t concurrency, Body body, Policy(), node_priority_t, priority = no_priority )'
3551c0b2f7Stbbdev   - [ ] Explicit test for copy constructor of the node.
3651c0b2f7Stbbdev   - [ ] Rename test_broadcast to test_forwarding and check that the value passed is the actual one
3751c0b2f7Stbbdev     received.
  - [ ] Concurrency testing of the node: loop over the possible concurrency levels. It is
    important to test at least five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism'
    obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if `max_allowed_parallelism'
    is > 2, a value in the middle of the [1, max_allowed_parallelism]
    interval. Use the `utils::ExactConcurrencyLevel' entity (extending it if necessary).
4351c0b2f7Stbbdev   - [ ] make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads;
4451c0b2f7Stbbdev     add check that `try_put()' returns `false'
  - [ ] Check that the copy constructor and copy assignment are called for the node's input and output types.
4651c0b2f7Stbbdev   - [ ] The `copy_body' function copies altered body (e.g. after successful `try_put()' call).
4751c0b2f7Stbbdev   - [ ] Extend CTAD test to check all node's constructors.
4851c0b2f7Stbbdev */
4951c0b2f7Stbbdev 
// Number of concurrency_functor invocations currently in progress.
std::atomic<size_t> my_concurrency{0};
// Largest value my_concurrency has been observed to reach since the last reset.
std::atomic<size_t> my_max_concurrency{0};
5251c0b2f7Stbbdev 
5351c0b2f7Stbbdev template< typename OutputType >
5451c0b2f7Stbbdev struct concurrency_functor {
5551c0b2f7Stbbdev     OutputType operator()( int argument ) {
5651c0b2f7Stbbdev         ++my_concurrency;
5751c0b2f7Stbbdev 
5851c0b2f7Stbbdev         size_t old_value = my_max_concurrency;
5951c0b2f7Stbbdev         while(my_max_concurrency < my_concurrency &&
6051c0b2f7Stbbdev               !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency))
6151c0b2f7Stbbdev             ;
6251c0b2f7Stbbdev 
6351c0b2f7Stbbdev         size_t ms = 1000;
6451c0b2f7Stbbdev         std::chrono::milliseconds sleep_time( ms );
6551c0b2f7Stbbdev         std::this_thread::sleep_for( sleep_time );
6651c0b2f7Stbbdev 
6751c0b2f7Stbbdev         --my_concurrency;
6851c0b2f7Stbbdev        return argument;
6951c0b2f7Stbbdev     }
7051c0b2f7Stbbdev 
7151c0b2f7Stbbdev };
7251c0b2f7Stbbdev 
7351c0b2f7Stbbdev void test_func_body(){
74*49e08aacStbbdev     oneapi::tbb::flow::graph g;
7551c0b2f7Stbbdev     inc_functor<int> fun;
7651c0b2f7Stbbdev     fun.execute_count = 0;
7751c0b2f7Stbbdev 
78*49e08aacStbbdev     oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun);
7951c0b2f7Stbbdev 
8051c0b2f7Stbbdev     const size_t n = 10;
8151c0b2f7Stbbdev     for(size_t i = 0; i < n; ++i) {
8251c0b2f7Stbbdev         CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true");
8351c0b2f7Stbbdev     }
8451c0b2f7Stbbdev     g.wait_for_all();
8551c0b2f7Stbbdev 
8651c0b2f7Stbbdev     CHECK_MESSAGE( (fun.execute_count == n), "Body of the node needs to be executed N times");
8751c0b2f7Stbbdev }
8851c0b2f7Stbbdev 
8951c0b2f7Stbbdev void test_priority(){
9051c0b2f7Stbbdev     size_t concurrency_limit = 1;
91*49e08aacStbbdev     oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit);
9251c0b2f7Stbbdev 
93*49e08aacStbbdev     oneapi::tbb::flow::graph g;
9451c0b2f7Stbbdev 
9551c0b2f7Stbbdev     first_functor<int>::first_id.store(-1);
9651c0b2f7Stbbdev     first_functor<int> low_functor(1);
9751c0b2f7Stbbdev     first_functor<int> high_functor(2);
9851c0b2f7Stbbdev 
99*49e08aacStbbdev     oneapi::tbb::flow::continue_node<int> source(g, [&](oneapi::tbb::flow::continue_msg){return 1;} );
10051c0b2f7Stbbdev 
101*49e08aacStbbdev     oneapi::tbb::flow::function_node<int, int> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1));
102*49e08aacStbbdev     oneapi::tbb::flow::function_node<int, int> low(g, oneapi::tbb::flow::unlimited, low_functor);
10351c0b2f7Stbbdev 
10451c0b2f7Stbbdev     make_edge(source, low);
10551c0b2f7Stbbdev     make_edge(source, high);
10651c0b2f7Stbbdev 
107*49e08aacStbbdev     source.try_put(oneapi::tbb::flow::continue_msg());
10851c0b2f7Stbbdev     g.wait_for_all();
10951c0b2f7Stbbdev 
11051c0b2f7Stbbdev     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
11151c0b2f7Stbbdev }
11251c0b2f7Stbbdev 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Checks that CTAD deduces function_node<int, int> from a body of type int(const int&).
void test_deduction_guides(){
    namespace flow = oneapi::tbb::flow;
    flow::graph g;

    auto body = [](const int&) -> int { return 1; };
    flow::function_node f1(g, flow::unlimited, body);

    using deduced_type = decltype(f1);
    CHECK_MESSAGE((std::is_same_v<deduced_type, flow::function_node<int, int>>), "Function node type must be deducible from its body");
}
#endif
12351c0b2f7Stbbdev 
12451c0b2f7Stbbdev void test_broadcast(){
125*49e08aacStbbdev     oneapi::tbb::flow::graph g;
12651c0b2f7Stbbdev     passthru_body fun;
12751c0b2f7Stbbdev 
128*49e08aacStbbdev     oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun);
12951c0b2f7Stbbdev     test_push_receiver<int> node2(g);
13051c0b2f7Stbbdev     test_push_receiver<int> node3(g);
13151c0b2f7Stbbdev 
132*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
133*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node3);
13451c0b2f7Stbbdev 
13551c0b2f7Stbbdev     node1.try_put(1);
13651c0b2f7Stbbdev     g.wait_for_all();
13751c0b2f7Stbbdev 
13851c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
13951c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
14051c0b2f7Stbbdev }
14151c0b2f7Stbbdev 
14251c0b2f7Stbbdev template<typename Policy>
14351c0b2f7Stbbdev void test_buffering(){
144*49e08aacStbbdev     oneapi::tbb::flow::graph g;
14551c0b2f7Stbbdev     passthru_body fun;
14651c0b2f7Stbbdev 
147*49e08aacStbbdev     oneapi::tbb::flow::function_node<int, int, Policy> node(g, oneapi::tbb::flow::unlimited, fun);
148*49e08aacStbbdev     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
14951c0b2f7Stbbdev 
150*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node, rejecter);
15151c0b2f7Stbbdev     node.try_put(1);
15251c0b2f7Stbbdev 
15351c0b2f7Stbbdev     int tmp = -1;
15451c0b2f7Stbbdev     CHECK_MESSAGE( (node.try_get(tmp) == false), "try_get after rejection should not succeed");
15551c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should not alter passed value");
15651c0b2f7Stbbdev     g.wait_for_all();
15751c0b2f7Stbbdev }
15851c0b2f7Stbbdev 
15951c0b2f7Stbbdev void test_node_concurrency(){
16051c0b2f7Stbbdev     my_concurrency = 0;
16151c0b2f7Stbbdev     my_max_concurrency = 0;
16251c0b2f7Stbbdev 
163*49e08aacStbbdev     oneapi::tbb::flow::graph g;
16451c0b2f7Stbbdev     concurrency_functor<int> counter;
165*49e08aacStbbdev     oneapi::tbb::flow::function_node <int, int> fnode(g, oneapi::tbb::flow::serial, counter);
16651c0b2f7Stbbdev 
16751c0b2f7Stbbdev     test_push_receiver<int> sink(g);
16851c0b2f7Stbbdev 
16951c0b2f7Stbbdev     make_edge(fnode, sink);
17051c0b2f7Stbbdev 
17151c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
17251c0b2f7Stbbdev         fnode.try_put(i);
17351c0b2f7Stbbdev     }
17451c0b2f7Stbbdev 
17551c0b2f7Stbbdev     g.wait_for_all();
17651c0b2f7Stbbdev 
17751c0b2f7Stbbdev     CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism is not expected");
17851c0b2f7Stbbdev }
17951c0b2f7Stbbdev 
18051c0b2f7Stbbdev template<typename I, typename O>
18151c0b2f7Stbbdev void test_inheritance(){
182*49e08aacStbbdev     using namespace oneapi::tbb::flow;
18351c0b2f7Stbbdev 
18451c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, function_node<I, O>>::value), "function_node should be derived from graph_node");
18551c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<receiver<I>, function_node<I, O>>::value), "function_node should be derived from receiver<Input>");
18651c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<sender<O>, function_node<I, O>>::value), "function_node should be derived from sender<Output>");
18751c0b2f7Stbbdev }
18851c0b2f7Stbbdev 
18951c0b2f7Stbbdev void test_policy_ctors(){
190*49e08aacStbbdev     using namespace oneapi::tbb::flow;
19151c0b2f7Stbbdev     graph g;
19251c0b2f7Stbbdev 
193*49e08aacStbbdev     function_node<int, int, lightweight> lw_node(g, oneapi::tbb::flow::serial,
19451c0b2f7Stbbdev                                                           [](int v) { return v;});
195*49e08aacStbbdev     function_node<int, int, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial,
19651c0b2f7Stbbdev                                                           [](int v) { return v;});
197*49e08aacStbbdev     function_node<int, int, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial,
19851c0b2f7Stbbdev                                                           [](int v) { return v;});
19951c0b2f7Stbbdev 
20051c0b2f7Stbbdev }
20151c0b2f7Stbbdev 
//! Body that remembers whether it has been invoked.
/** `stored` is -1 after construction and becomes 1 on the first call;
    the argument is returned unchanged. */
class stateful_functor{
public:
    int stored;

    stateful_functor() : stored{-1} {}

    int operator()(int value){
        stored = 1;
        return value;
    }
};
20851c0b2f7Stbbdev 
20951c0b2f7Stbbdev void test_ctors(){
210*49e08aacStbbdev     using namespace oneapi::tbb::flow;
21151c0b2f7Stbbdev     graph g;
21251c0b2f7Stbbdev 
21351c0b2f7Stbbdev     function_node<int, int> fn(g, unlimited, stateful_functor());
21451c0b2f7Stbbdev     fn.try_put(0);
21551c0b2f7Stbbdev     g.wait_for_all();
21651c0b2f7Stbbdev 
21751c0b2f7Stbbdev     stateful_functor b1 = copy_body<stateful_functor, function_node<int, int>>(fn);
21851c0b2f7Stbbdev     CHECK_MESSAGE( (b1.stored == 1), "First node should update");
21951c0b2f7Stbbdev 
22051c0b2f7Stbbdev     function_node<int, int> fn2(fn);
22151c0b2f7Stbbdev     stateful_functor b2 = copy_body<stateful_functor, function_node<int, int>>(fn2);
22251c0b2f7Stbbdev     CHECK_MESSAGE( (b2.stored == -1), "Copied node should not update");
22351c0b2f7Stbbdev }
22451c0b2f7Stbbdev 
//! Identity body that counts the copy operations leading to the current instance.
/** A default-constructed body has copy_count == 0. Both copy construction and
    copy assignment set copy_count to the source's count plus one. */
template<typename I, typename O>
struct CopyCounterBody{
    size_t copy_count;

    CopyCounterBody() : copy_count{0} {}

    CopyCounterBody(const CopyCounterBody& other) : copy_count{other.copy_count + 1} {}

    CopyCounterBody& operator=(const CopyCounterBody& other) {
        copy_count = other.copy_count + 1;
        return *this;
    }

    O operator()(I in){ return in; }
};
24251c0b2f7Stbbdev 
24351c0b2f7Stbbdev void test_copies(){
244*49e08aacStbbdev     using namespace oneapi::tbb::flow;
24551c0b2f7Stbbdev 
24651c0b2f7Stbbdev     CopyCounterBody<int, int> b;
24751c0b2f7Stbbdev 
24851c0b2f7Stbbdev     graph g;
24951c0b2f7Stbbdev     function_node<int, int> fn(g, unlimited, b);
25051c0b2f7Stbbdev 
25151c0b2f7Stbbdev     CopyCounterBody<int, int> b2 = copy_body<CopyCounterBody<int, int>, function_node<int, int>>(fn);
25251c0b2f7Stbbdev 
25351c0b2f7Stbbdev     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
25451c0b2f7Stbbdev }
25551c0b2f7Stbbdev 
25651c0b2f7Stbbdev void test_rejecting(){
257*49e08aacStbbdev     oneapi::tbb::flow::graph g;
258*49e08aacStbbdev     oneapi::tbb::flow::function_node <int, int, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial,
25951c0b2f7Stbbdev                                                                     [&](int v){
26051c0b2f7Stbbdev                                                                         size_t ms = 50;
26151c0b2f7Stbbdev                                                                         std::chrono::milliseconds sleep_time( ms );
26251c0b2f7Stbbdev                                                                         std::this_thread::sleep_for( sleep_time );
26351c0b2f7Stbbdev                                                                         return v;
26451c0b2f7Stbbdev                                                                     });
26551c0b2f7Stbbdev 
26651c0b2f7Stbbdev     test_push_receiver<int> sink(g);
26751c0b2f7Stbbdev 
26851c0b2f7Stbbdev     make_edge(fnode, sink);
26951c0b2f7Stbbdev 
27051c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
27151c0b2f7Stbbdev         fnode.try_put(i);
27251c0b2f7Stbbdev     }
27351c0b2f7Stbbdev 
27451c0b2f7Stbbdev     g.wait_for_all();
27551c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed");
27651c0b2f7Stbbdev }
27751c0b2f7Stbbdev 
27851c0b2f7Stbbdev //! Test function_node with rejecting policy
27951c0b2f7Stbbdev //! \brief \ref interface
28051c0b2f7Stbbdev TEST_CASE("function_node with rejecting policy"){
28151c0b2f7Stbbdev     test_rejecting();
28251c0b2f7Stbbdev }
28351c0b2f7Stbbdev 
28451c0b2f7Stbbdev //! Test body copying and copy_body logic
28551c0b2f7Stbbdev //! \brief \ref interface
28651c0b2f7Stbbdev TEST_CASE("function_node and body copying"){
28751c0b2f7Stbbdev     test_copies();
28851c0b2f7Stbbdev }
28951c0b2f7Stbbdev 
29051c0b2f7Stbbdev //! Test constructors
29151c0b2f7Stbbdev //! \brief \ref interface
29251c0b2f7Stbbdev TEST_CASE("function_node constructors"){
29351c0b2f7Stbbdev     test_policy_ctors();
29451c0b2f7Stbbdev }
29551c0b2f7Stbbdev 
29651c0b2f7Stbbdev //! Test inheritance relations
29751c0b2f7Stbbdev //! \brief \ref interface
29851c0b2f7Stbbdev TEST_CASE("function_node superclasses"){
29951c0b2f7Stbbdev     test_inheritance<int, int>();
30051c0b2f7Stbbdev     test_inheritance<void*, float>();
30151c0b2f7Stbbdev }
30251c0b2f7Stbbdev 
30351c0b2f7Stbbdev //! Test function_node buffering
30451c0b2f7Stbbdev //! \brief \ref requirement
30551c0b2f7Stbbdev TEST_CASE("function_node buffering"){
306*49e08aacStbbdev     test_buffering<oneapi::tbb::flow::rejecting>();
307*49e08aacStbbdev     test_buffering<oneapi::tbb::flow::queueing>();
30851c0b2f7Stbbdev }
30951c0b2f7Stbbdev 
31051c0b2f7Stbbdev //! Test function_node broadcasting
31151c0b2f7Stbbdev //! \brief \ref requirement
31251c0b2f7Stbbdev TEST_CASE("function_node broadcast"){
31351c0b2f7Stbbdev     test_broadcast();
31451c0b2f7Stbbdev }
31551c0b2f7Stbbdev 
31651c0b2f7Stbbdev //! Test deduction guides
31751c0b2f7Stbbdev //! \brief \ref interface \ref requirement
31851c0b2f7Stbbdev TEST_CASE("Deduction guides"){
31951c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
32051c0b2f7Stbbdev     test_deduction_guides();
32151c0b2f7Stbbdev #endif
32251c0b2f7Stbbdev }
32351c0b2f7Stbbdev 
32451c0b2f7Stbbdev //! Test priorities work in single-threaded configuration
32551c0b2f7Stbbdev //! \brief \ref requirement
32651c0b2f7Stbbdev TEST_CASE("function_node priority support"){
32751c0b2f7Stbbdev     test_priority();
32851c0b2f7Stbbdev }
32951c0b2f7Stbbdev 
33051c0b2f7Stbbdev //! Test that measured concurrency respects set limits
33151c0b2f7Stbbdev //! \brief \ref requirement
33251c0b2f7Stbbdev TEST_CASE("concurrency follows set limits"){
33351c0b2f7Stbbdev     test_node_concurrency();
33451c0b2f7Stbbdev }
33551c0b2f7Stbbdev 
33651c0b2f7Stbbdev //! Test calling function body
33751c0b2f7Stbbdev //! \brief \ref interface \ref requirement
33851c0b2f7Stbbdev TEST_CASE("Test function_node body") {
33951c0b2f7Stbbdev     test_func_body();
34051c0b2f7Stbbdev }
341