151c0b2f7Stbbdev /*
2*b15aabb3Stbbdev     Copyright (c) 2020-2021 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
17*b15aabb3Stbbdev #if __INTEL_COMPILER && _MSC_VER
18*b15aabb3Stbbdev #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19*b15aabb3Stbbdev #endif
20*b15aabb3Stbbdev 
2151c0b2f7Stbbdev #include "common/test.h"
2251c0b2f7Stbbdev 
2351c0b2f7Stbbdev #include "common/utils.h"
2451c0b2f7Stbbdev #include "common/graph_utils.h"
2551c0b2f7Stbbdev 
2649e08aacStbbdev #include "oneapi/tbb/flow_graph.h"
2749e08aacStbbdev #include "oneapi/tbb/task_arena.h"
2849e08aacStbbdev #include "oneapi/tbb/global_control.h"
2951c0b2f7Stbbdev 
3051c0b2f7Stbbdev #include "conformance_flowgraph.h"
3151c0b2f7Stbbdev 
3251c0b2f7Stbbdev //! \file conformance_function_node.cpp
3351c0b2f7Stbbdev //! \brief Test for [flow_graph.function_node] specification
3451c0b2f7Stbbdev 
3551c0b2f7Stbbdev /*
3651c0b2f7Stbbdev TODO: implement missing conformance tests for function_node:
  - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body> function_node(
    graph &g, size_t concurrency, Body body, Policy(), node_priority_t priority = no_priority )'
3951c0b2f7Stbbdev   - [ ] Explicit test for copy constructor of the node.
4051c0b2f7Stbbdev   - [ ] Rename test_broadcast to test_forwarding and check that the value passed is the actual one
4151c0b2f7Stbbdev     received.
4251c0b2f7Stbbdev   - [ ] Concurrency testing of the node: make a loop over possible concurrency levels. It is
4349e08aacStbbdev     important to test at least on five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism'
4449e08aacStbbdev     obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if `max allowed
4551c0b2f7Stbbdev     parallelism' is > 2, use something in the middle of the [1, max_allowed_parallelism]
4651c0b2f7Stbbdev     interval. Use `utils::ExactConcurrencyLevel' entity (extending it if necessary).
4751c0b2f7Stbbdev   - [ ] make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads;
4851c0b2f7Stbbdev     add check that `try_put()' returns `false'
4951c0b2f7Stbbdev   - [ ] The copy constructor and copy assignment are called for the node's input and output types.
5051c0b2f7Stbbdev   - [ ] The `copy_body' function copies altered body (e.g. after successful `try_put()' call).
5151c0b2f7Stbbdev   - [ ] Extend CTAD test to check all node's constructors.
5251c0b2f7Stbbdev */
5351c0b2f7Stbbdev 
// Number of concurrency_functor invocations that are in flight right now.
std::atomic<size_t> my_concurrency{0};
// Largest value of my_concurrency recorded by concurrency_functor so far.
std::atomic<size_t> my_max_concurrency{0};
5651c0b2f7Stbbdev 
5751c0b2f7Stbbdev template< typename OutputType >
5851c0b2f7Stbbdev struct concurrency_functor {
5951c0b2f7Stbbdev     OutputType operator()( int argument ) {
6051c0b2f7Stbbdev         ++my_concurrency;
6151c0b2f7Stbbdev 
6251c0b2f7Stbbdev         size_t old_value = my_max_concurrency;
6351c0b2f7Stbbdev         while(my_max_concurrency < my_concurrency &&
6451c0b2f7Stbbdev               !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency))
6551c0b2f7Stbbdev             ;
6651c0b2f7Stbbdev 
6751c0b2f7Stbbdev         size_t ms = 1000;
6851c0b2f7Stbbdev         std::chrono::milliseconds sleep_time( ms );
6951c0b2f7Stbbdev         std::this_thread::sleep_for( sleep_time );
7051c0b2f7Stbbdev 
7151c0b2f7Stbbdev         --my_concurrency;
7251c0b2f7Stbbdev        return argument;
7351c0b2f7Stbbdev     }
7451c0b2f7Stbbdev 
7551c0b2f7Stbbdev };
7651c0b2f7Stbbdev 
7751c0b2f7Stbbdev void test_func_body(){
7849e08aacStbbdev     oneapi::tbb::flow::graph g;
7951c0b2f7Stbbdev     inc_functor<int> fun;
8051c0b2f7Stbbdev     fun.execute_count = 0;
8151c0b2f7Stbbdev 
8249e08aacStbbdev     oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun);
8351c0b2f7Stbbdev 
8451c0b2f7Stbbdev     const size_t n = 10;
8551c0b2f7Stbbdev     for(size_t i = 0; i < n; ++i) {
8651c0b2f7Stbbdev         CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true");
8751c0b2f7Stbbdev     }
8851c0b2f7Stbbdev     g.wait_for_all();
8951c0b2f7Stbbdev 
9051c0b2f7Stbbdev     CHECK_MESSAGE( (fun.execute_count == n), "Body of the node needs to be executed N times");
9151c0b2f7Stbbdev }
9251c0b2f7Stbbdev 
9351c0b2f7Stbbdev void test_priority(){
9451c0b2f7Stbbdev     size_t concurrency_limit = 1;
9549e08aacStbbdev     oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit);
9651c0b2f7Stbbdev 
9749e08aacStbbdev     oneapi::tbb::flow::graph g;
9851c0b2f7Stbbdev 
9951c0b2f7Stbbdev     first_functor<int>::first_id.store(-1);
10051c0b2f7Stbbdev     first_functor<int> low_functor(1);
10151c0b2f7Stbbdev     first_functor<int> high_functor(2);
10251c0b2f7Stbbdev 
10349e08aacStbbdev     oneapi::tbb::flow::continue_node<int> source(g, [&](oneapi::tbb::flow::continue_msg){return 1;} );
10451c0b2f7Stbbdev 
10549e08aacStbbdev     oneapi::tbb::flow::function_node<int, int> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1));
10649e08aacStbbdev     oneapi::tbb::flow::function_node<int, int> low(g, oneapi::tbb::flow::unlimited, low_functor);
10751c0b2f7Stbbdev 
10851c0b2f7Stbbdev     make_edge(source, low);
10951c0b2f7Stbbdev     make_edge(source, high);
11051c0b2f7Stbbdev 
11149e08aacStbbdev     source.try_put(oneapi::tbb::flow::continue_msg());
11251c0b2f7Stbbdev     g.wait_for_all();
11351c0b2f7Stbbdev 
11451c0b2f7Stbbdev     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
11551c0b2f7Stbbdev }
11651c0b2f7Stbbdev 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
// Verifies that class template argument deduction produces function_node<int, int> when
// the node is built from a body taking `const int&` and returning `int`.
void test_deduction_guides(){
    namespace flow = oneapi::tbb::flow;
    flow::graph g;

    auto int_to_int = [](const int&)->int { return 1; };
    flow::function_node deduced(g, flow::unlimited, int_to_int);

    using deduced_type = decltype(deduced);
    CHECK_MESSAGE((std::is_same_v<deduced_type, flow::function_node<int, int>>), "Function node type must be deducible from its body");
}
#endif
12751c0b2f7Stbbdev 
12851c0b2f7Stbbdev void test_broadcast(){
12949e08aacStbbdev     oneapi::tbb::flow::graph g;
13051c0b2f7Stbbdev     passthru_body fun;
13151c0b2f7Stbbdev 
13249e08aacStbbdev     oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun);
13351c0b2f7Stbbdev     test_push_receiver<int> node2(g);
13451c0b2f7Stbbdev     test_push_receiver<int> node3(g);
13551c0b2f7Stbbdev 
13649e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
13749e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node3);
13851c0b2f7Stbbdev 
13951c0b2f7Stbbdev     node1.try_put(1);
14051c0b2f7Stbbdev     g.wait_for_all();
14151c0b2f7Stbbdev 
14251c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
14351c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
14451c0b2f7Stbbdev }
14551c0b2f7Stbbdev 
14651c0b2f7Stbbdev template<typename Policy>
14751c0b2f7Stbbdev void test_buffering(){
14849e08aacStbbdev     oneapi::tbb::flow::graph g;
14951c0b2f7Stbbdev     passthru_body fun;
15051c0b2f7Stbbdev 
15149e08aacStbbdev     oneapi::tbb::flow::function_node<int, int, Policy> node(g, oneapi::tbb::flow::unlimited, fun);
15249e08aacStbbdev     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
15351c0b2f7Stbbdev 
15449e08aacStbbdev     oneapi::tbb::flow::make_edge(node, rejecter);
15551c0b2f7Stbbdev     node.try_put(1);
15651c0b2f7Stbbdev 
15751c0b2f7Stbbdev     int tmp = -1;
15851c0b2f7Stbbdev     CHECK_MESSAGE( (node.try_get(tmp) == false), "try_get after rejection should not succeed");
15951c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should not alter passed value");
16051c0b2f7Stbbdev     g.wait_for_all();
16151c0b2f7Stbbdev }
16251c0b2f7Stbbdev 
16351c0b2f7Stbbdev void test_node_concurrency(){
16451c0b2f7Stbbdev     my_concurrency = 0;
16551c0b2f7Stbbdev     my_max_concurrency = 0;
16651c0b2f7Stbbdev 
16749e08aacStbbdev     oneapi::tbb::flow::graph g;
16851c0b2f7Stbbdev     concurrency_functor<int> counter;
16949e08aacStbbdev     oneapi::tbb::flow::function_node <int, int> fnode(g, oneapi::tbb::flow::serial, counter);
17051c0b2f7Stbbdev 
17151c0b2f7Stbbdev     test_push_receiver<int> sink(g);
17251c0b2f7Stbbdev 
17351c0b2f7Stbbdev     make_edge(fnode, sink);
17451c0b2f7Stbbdev 
17551c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
17651c0b2f7Stbbdev         fnode.try_put(i);
17751c0b2f7Stbbdev     }
17851c0b2f7Stbbdev 
17951c0b2f7Stbbdev     g.wait_for_all();
18051c0b2f7Stbbdev 
18151c0b2f7Stbbdev     CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism is not expected");
18251c0b2f7Stbbdev }
18351c0b2f7Stbbdev 
18451c0b2f7Stbbdev template<typename I, typename O>
18551c0b2f7Stbbdev void test_inheritance(){
18649e08aacStbbdev     using namespace oneapi::tbb::flow;
18751c0b2f7Stbbdev 
18851c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, function_node<I, O>>::value), "function_node should be derived from graph_node");
18951c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<receiver<I>, function_node<I, O>>::value), "function_node should be derived from receiver<Input>");
19051c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<sender<O>, function_node<I, O>>::value), "function_node should be derived from sender<Output>");
19151c0b2f7Stbbdev }
19251c0b2f7Stbbdev 
19351c0b2f7Stbbdev void test_policy_ctors(){
19449e08aacStbbdev     using namespace oneapi::tbb::flow;
19551c0b2f7Stbbdev     graph g;
19651c0b2f7Stbbdev 
19749e08aacStbbdev     function_node<int, int, lightweight> lw_node(g, oneapi::tbb::flow::serial,
19851c0b2f7Stbbdev                                                           [](int v) { return v;});
19949e08aacStbbdev     function_node<int, int, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial,
20051c0b2f7Stbbdev                                                           [](int v) { return v;});
20149e08aacStbbdev     function_node<int, int, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial,
20251c0b2f7Stbbdev                                                           [](int v) { return v;});
20351c0b2f7Stbbdev 
20451c0b2f7Stbbdev }
20551c0b2f7Stbbdev 
// Body with observable state: `stored` starts at -1 and is set to 1 by the first invocation.
struct stateful_functor{
    int stored;

    stateful_functor()
        : stored(-1)
    {}

    // Marks the body as executed and returns the input unchanged.
    int operator()(int value){
        stored = 1;
        return value;
    }
};
21251c0b2f7Stbbdev 
21351c0b2f7Stbbdev void test_ctors(){
21449e08aacStbbdev     using namespace oneapi::tbb::flow;
21551c0b2f7Stbbdev     graph g;
21651c0b2f7Stbbdev 
21751c0b2f7Stbbdev     function_node<int, int> fn(g, unlimited, stateful_functor());
21851c0b2f7Stbbdev     fn.try_put(0);
21951c0b2f7Stbbdev     g.wait_for_all();
22051c0b2f7Stbbdev 
22151c0b2f7Stbbdev     stateful_functor b1 = copy_body<stateful_functor, function_node<int, int>>(fn);
22251c0b2f7Stbbdev     CHECK_MESSAGE( (b1.stored == 1), "First node should update");
22351c0b2f7Stbbdev 
22451c0b2f7Stbbdev     function_node<int, int> fn2(fn);
22551c0b2f7Stbbdev     stateful_functor b2 = copy_body<stateful_functor, function_node<int, int>>(fn2);
22651c0b2f7Stbbdev     CHECK_MESSAGE( (b2.stored == -1), "Copied node should not update");
22751c0b2f7Stbbdev }
22851c0b2f7Stbbdev 
// Pass-through body that tracks its copy depth: a default-constructed body has
// copy_count == 0, and every copy construction or copy assignment sets copy_count to the
// source's value plus one.
template<typename I, typename O>
struct CopyCounterBody{
    size_t copy_count;

    CopyCounterBody()
        : copy_count(0)
    {}

    CopyCounterBody(const CopyCounterBody& other)
        : copy_count(other.copy_count + 1)
    {}

    CopyCounterBody& operator=(const CopyCounterBody& other){
        copy_count = other.copy_count + 1;
        return *this;
    }

    // Returns the input converted to the output type.
    O operator()(I in){
        return in;
    }
};
24651c0b2f7Stbbdev 
24751c0b2f7Stbbdev void test_copies(){
24849e08aacStbbdev     using namespace oneapi::tbb::flow;
24951c0b2f7Stbbdev 
25051c0b2f7Stbbdev     CopyCounterBody<int, int> b;
25151c0b2f7Stbbdev 
25251c0b2f7Stbbdev     graph g;
25351c0b2f7Stbbdev     function_node<int, int> fn(g, unlimited, b);
25451c0b2f7Stbbdev 
25551c0b2f7Stbbdev     CopyCounterBody<int, int> b2 = copy_body<CopyCounterBody<int, int>, function_node<int, int>>(fn);
25651c0b2f7Stbbdev 
25751c0b2f7Stbbdev     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
25851c0b2f7Stbbdev }
25951c0b2f7Stbbdev 
26051c0b2f7Stbbdev void test_rejecting(){
26149e08aacStbbdev     oneapi::tbb::flow::graph g;
26249e08aacStbbdev     oneapi::tbb::flow::function_node <int, int, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial,
26351c0b2f7Stbbdev                                                                     [&](int v){
26451c0b2f7Stbbdev                                                                         size_t ms = 50;
26551c0b2f7Stbbdev                                                                         std::chrono::milliseconds sleep_time( ms );
26651c0b2f7Stbbdev                                                                         std::this_thread::sleep_for( sleep_time );
26751c0b2f7Stbbdev                                                                         return v;
26851c0b2f7Stbbdev                                                                     });
26951c0b2f7Stbbdev 
27051c0b2f7Stbbdev     test_push_receiver<int> sink(g);
27151c0b2f7Stbbdev 
27251c0b2f7Stbbdev     make_edge(fnode, sink);
27351c0b2f7Stbbdev 
27451c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
27551c0b2f7Stbbdev         fnode.try_put(i);
27651c0b2f7Stbbdev     }
27751c0b2f7Stbbdev 
27851c0b2f7Stbbdev     g.wait_for_all();
27951c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed");
28051c0b2f7Stbbdev }
28151c0b2f7Stbbdev 
28251c0b2f7Stbbdev //! Test function_node with rejecting policy
28351c0b2f7Stbbdev //! \brief \ref interface
28451c0b2f7Stbbdev TEST_CASE("function_node with rejecting policy"){
28551c0b2f7Stbbdev     test_rejecting();
28651c0b2f7Stbbdev }
28751c0b2f7Stbbdev 
28851c0b2f7Stbbdev //! Test body copying and copy_body logic
28951c0b2f7Stbbdev //! \brief \ref interface
29051c0b2f7Stbbdev TEST_CASE("function_node and body copying"){
29151c0b2f7Stbbdev     test_copies();
29251c0b2f7Stbbdev }
29351c0b2f7Stbbdev 
29451c0b2f7Stbbdev //! Test constructors
29551c0b2f7Stbbdev //! \brief \ref interface
29651c0b2f7Stbbdev TEST_CASE("function_node constructors"){
29751c0b2f7Stbbdev     test_policy_ctors();
29851c0b2f7Stbbdev }
29951c0b2f7Stbbdev 
30051c0b2f7Stbbdev //! Test inheritance relations
30151c0b2f7Stbbdev //! \brief \ref interface
30251c0b2f7Stbbdev TEST_CASE("function_node superclasses"){
30351c0b2f7Stbbdev     test_inheritance<int, int>();
30451c0b2f7Stbbdev     test_inheritance<void*, float>();
30551c0b2f7Stbbdev }
30651c0b2f7Stbbdev 
30751c0b2f7Stbbdev //! Test function_node buffering
30851c0b2f7Stbbdev //! \brief \ref requirement
30951c0b2f7Stbbdev TEST_CASE("function_node buffering"){
31049e08aacStbbdev     test_buffering<oneapi::tbb::flow::rejecting>();
31149e08aacStbbdev     test_buffering<oneapi::tbb::flow::queueing>();
31251c0b2f7Stbbdev }
31351c0b2f7Stbbdev 
31451c0b2f7Stbbdev //! Test function_node broadcasting
31551c0b2f7Stbbdev //! \brief \ref requirement
31651c0b2f7Stbbdev TEST_CASE("function_node broadcast"){
31751c0b2f7Stbbdev     test_broadcast();
31851c0b2f7Stbbdev }
31951c0b2f7Stbbdev 
32051c0b2f7Stbbdev //! Test deduction guides
32151c0b2f7Stbbdev //! \brief \ref interface \ref requirement
32251c0b2f7Stbbdev TEST_CASE("Deduction guides"){
32351c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
32451c0b2f7Stbbdev     test_deduction_guides();
32551c0b2f7Stbbdev #endif
32651c0b2f7Stbbdev }
32751c0b2f7Stbbdev 
32851c0b2f7Stbbdev //! Test priorities work in single-threaded configuration
32951c0b2f7Stbbdev //! \brief \ref requirement
33051c0b2f7Stbbdev TEST_CASE("function_node priority support"){
33151c0b2f7Stbbdev     test_priority();
33251c0b2f7Stbbdev }
33351c0b2f7Stbbdev 
33451c0b2f7Stbbdev //! Test that measured concurrency respects set limits
33551c0b2f7Stbbdev //! \brief \ref requirement
33651c0b2f7Stbbdev TEST_CASE("concurrency follows set limits"){
33751c0b2f7Stbbdev     test_node_concurrency();
33851c0b2f7Stbbdev }
33951c0b2f7Stbbdev 
34051c0b2f7Stbbdev //! Test calling function body
34151c0b2f7Stbbdev //! \brief \ref interface \ref requirement
34251c0b2f7Stbbdev TEST_CASE("Test function_node body") {
34351c0b2f7Stbbdev     test_func_body();
34451c0b2f7Stbbdev }
345