151c0b2f7Stbbdev /*
251c0b2f7Stbbdev     Copyright (c) 2020 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
1751c0b2f7Stbbdev #include "common/test.h"
1851c0b2f7Stbbdev 
1951c0b2f7Stbbdev #include "common/utils.h"
2051c0b2f7Stbbdev #include "common/graph_utils.h"
2151c0b2f7Stbbdev 
22*49e08aacStbbdev #include "oneapi/tbb/flow_graph.h"
23*49e08aacStbbdev #include "oneapi/tbb/task_arena.h"
2451c0b2f7Stbbdev 
25*49e08aacStbbdev #include "oneapi/tbb/global_control.h"
2651c0b2f7Stbbdev #include "conformance_flowgraph.h"
2751c0b2f7Stbbdev 
2851c0b2f7Stbbdev //! \file conformance_continue_node.cpp
2951c0b2f7Stbbdev //! \brief Test for [flow_graph.continue_node] specification
3051c0b2f7Stbbdev 
/*
TODO: implement missing conformance tests for continue_node:
  - [ ] For `test_forwarding', check that the value passed is the one actually received.
  - [ ] Check that `copy_body' copies the altered body (e.g. after its successful invocation).
  - [ ] Improve CTAD test.
  - [ ] Improve constructors test, including addition of calls to constructors with the
    `number_of_predecessors' parameter.
  - [ ] Explicit test for the copy constructor of the node.
  - [ ] Rewrite test_priority.
  - [ ] Check that the `Output' type is indeed copy-constructed and copy-assigned while working with the node.
  - [ ] Explicit test for correct handling of the `number_of_predecessors' constructor parameter,
    including taking it into account when making and removing edges.
  - [ ] Add testing of the `try_put' call. In particular, check that it does not wait for the
    execution of the body to complete.
*/
4651c0b2f7Stbbdev 
4751c0b2f7Stbbdev void test_cont_body(){
48*49e08aacStbbdev     oneapi::tbb::flow::graph g;
4951c0b2f7Stbbdev     inc_functor<int> cf;
5051c0b2f7Stbbdev     cf.execute_count = 0;
5151c0b2f7Stbbdev 
52*49e08aacStbbdev     oneapi::tbb::flow::continue_node<int> node1(g, cf);
5351c0b2f7Stbbdev 
5451c0b2f7Stbbdev     const size_t n = 10;
5551c0b2f7Stbbdev     for(size_t i = 0; i < n; ++i) {
56*49e08aacStbbdev         CHECK_MESSAGE((node1.try_put(oneapi::tbb::flow::continue_msg()) == true),
5751c0b2f7Stbbdev                       "continue_node::try_put() should never reject a message.");
5851c0b2f7Stbbdev     }
5951c0b2f7Stbbdev     g.wait_for_all();
6051c0b2f7Stbbdev 
6151c0b2f7Stbbdev     CHECK_MESSAGE( (cf.execute_count == n), "Body of the first node needs to be executed N times");
6251c0b2f7Stbbdev }
6351c0b2f7Stbbdev 
6451c0b2f7Stbbdev template<typename O>
6551c0b2f7Stbbdev void test_inheritance(){
66*49e08aacStbbdev     using namespace oneapi::tbb::flow;
6751c0b2f7Stbbdev 
6851c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, continue_node<O>>::value), "continue_node should be derived from graph_node");
6951c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<receiver<continue_msg>, continue_node<O>>::value), "continue_node should be derived from receiver<Input>");
7051c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<sender<O>, continue_node<O>>::value), "continue_node should be derived from sender<Output>");
7151c0b2f7Stbbdev }
7251c0b2f7Stbbdev 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Checks class template argument deduction for continue_node(graph&, Body).
//! The Output type is expected to be deduced from the body's result for a continue_msg
//! argument. inc_functor<int> is assumed to produce int there (see conformance_flowgraph.h),
//! so the deduced node type should be continue_node<int>.
void test_deduction_guides(){
    oneapi::tbb::flow::graph g;
    inc_functor<int> fun;
    oneapi::tbb::flow::continue_node node1(g, fun);

    CHECK_MESSAGE( (std::is_same<decltype(node1), oneapi::tbb::flow::continue_node<int>>::value),
                   "continue_node(graph&, Body) should deduce Output from the body's return type");
}
#endif
8051c0b2f7Stbbdev 
8151c0b2f7Stbbdev void test_forwarding(){
82*49e08aacStbbdev     oneapi::tbb::flow::graph g;
8351c0b2f7Stbbdev     inc_functor<int> fun;
8451c0b2f7Stbbdev     fun.execute_count = 0;
8551c0b2f7Stbbdev 
86*49e08aacStbbdev     oneapi::tbb::flow::continue_node<int> node1(g, fun);
8751c0b2f7Stbbdev     test_push_receiver<int> node2(g);
8851c0b2f7Stbbdev     test_push_receiver<int> node3(g);
8951c0b2f7Stbbdev 
90*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
91*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node3);
9251c0b2f7Stbbdev 
93*49e08aacStbbdev     node1.try_put(oneapi::tbb::flow::continue_msg());
9451c0b2f7Stbbdev     g.wait_for_all();
9551c0b2f7Stbbdev 
9651c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
9751c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
9851c0b2f7Stbbdev }
9951c0b2f7Stbbdev 
10051c0b2f7Stbbdev void test_buffering(){
101*49e08aacStbbdev     oneapi::tbb::flow::graph g;
10251c0b2f7Stbbdev     inc_functor<int> fun;
10351c0b2f7Stbbdev 
104*49e08aacStbbdev     oneapi::tbb::flow::continue_node<int> node(g, fun);
105*49e08aacStbbdev     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
10651c0b2f7Stbbdev 
107*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node, rejecter);
108*49e08aacStbbdev     node.try_put(oneapi::tbb::flow::continue_msg());
10951c0b2f7Stbbdev 
11051c0b2f7Stbbdev     int tmp = -1;
11151c0b2f7Stbbdev     CHECK_MESSAGE( (node.try_get(tmp) == false), "try_get after rejection should not succeed");
11251c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should not alter passed value");
11351c0b2f7Stbbdev     g.wait_for_all();
11451c0b2f7Stbbdev }
11551c0b2f7Stbbdev 
11651c0b2f7Stbbdev void test_policy_ctors(){
117*49e08aacStbbdev     using namespace oneapi::tbb::flow;
11851c0b2f7Stbbdev     graph g;
11951c0b2f7Stbbdev 
12051c0b2f7Stbbdev     inc_functor<int> fun;
12151c0b2f7Stbbdev 
12251c0b2f7Stbbdev     continue_node<int, lightweight> lw_node(g, fun);
12351c0b2f7Stbbdev }
12451c0b2f7Stbbdev 
12551c0b2f7Stbbdev void test_ctors(){
126*49e08aacStbbdev     using namespace oneapi::tbb::flow;
12751c0b2f7Stbbdev     graph g;
12851c0b2f7Stbbdev 
12951c0b2f7Stbbdev     inc_functor<int> fun;
13051c0b2f7Stbbdev 
131*49e08aacStbbdev     continue_node<int> proto1(g, 2, fun, oneapi::tbb::flow::node_priority_t(1));
13251c0b2f7Stbbdev }
13351c0b2f7Stbbdev 
13451c0b2f7Stbbdev template<typename O>
13551c0b2f7Stbbdev struct CopyCounterBody{
13651c0b2f7Stbbdev     size_t copy_count;
13751c0b2f7Stbbdev 
13851c0b2f7Stbbdev     CopyCounterBody():
13951c0b2f7Stbbdev         copy_count(0) {}
14051c0b2f7Stbbdev 
14151c0b2f7Stbbdev     CopyCounterBody(const CopyCounterBody<O>& other):
14251c0b2f7Stbbdev         copy_count(other.copy_count + 1) {}
14351c0b2f7Stbbdev 
14451c0b2f7Stbbdev     CopyCounterBody& operator=(const CopyCounterBody<O>& other){
14551c0b2f7Stbbdev         copy_count = other.copy_count + 1;
14651c0b2f7Stbbdev         return *this;
14751c0b2f7Stbbdev     }
14851c0b2f7Stbbdev 
149*49e08aacStbbdev     O operator()(oneapi::tbb::flow::continue_msg){
15051c0b2f7Stbbdev         return 1;
15151c0b2f7Stbbdev     }
15251c0b2f7Stbbdev };
15351c0b2f7Stbbdev 
15451c0b2f7Stbbdev void test_copies(){
155*49e08aacStbbdev     using namespace oneapi::tbb::flow;
15651c0b2f7Stbbdev 
15751c0b2f7Stbbdev     CopyCounterBody<int> b;
15851c0b2f7Stbbdev 
15951c0b2f7Stbbdev     graph g;
16051c0b2f7Stbbdev     continue_node<int> fn(g, b);
16151c0b2f7Stbbdev 
16251c0b2f7Stbbdev     CopyCounterBody<int> b2 = copy_body<CopyCounterBody<int>,
16351c0b2f7Stbbdev                                              continue_node<int>>(fn);
16451c0b2f7Stbbdev 
16551c0b2f7Stbbdev     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
16651c0b2f7Stbbdev }
16751c0b2f7Stbbdev 
16851c0b2f7Stbbdev 
16951c0b2f7Stbbdev void test_priority(){
17051c0b2f7Stbbdev     size_t concurrency_limit = 1;
171*49e08aacStbbdev     oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit);
17251c0b2f7Stbbdev 
173*49e08aacStbbdev     oneapi::tbb::flow::graph g;
17451c0b2f7Stbbdev 
175*49e08aacStbbdev     oneapi::tbb::flow::continue_node<oneapi::tbb::flow::continue_msg> source(g,
176*49e08aacStbbdev                                                              [](oneapi::tbb::flow::continue_msg){ return oneapi::tbb::flow::continue_msg();});
177*49e08aacStbbdev     source.try_put(oneapi::tbb::flow::continue_msg());
17851c0b2f7Stbbdev 
17951c0b2f7Stbbdev     first_functor<int>::first_id = -1;
18051c0b2f7Stbbdev     first_functor<int> low_functor(1);
18151c0b2f7Stbbdev     first_functor<int> high_functor(2);
18251c0b2f7Stbbdev 
183*49e08aacStbbdev     oneapi::tbb::flow::continue_node<int, int> high(g, high_functor, oneapi::tbb::flow::node_priority_t(1));
184*49e08aacStbbdev     oneapi::tbb::flow::continue_node<int, int> low(g, low_functor);
18551c0b2f7Stbbdev 
18651c0b2f7Stbbdev     make_edge(source, low);
18751c0b2f7Stbbdev     make_edge(source, high);
18851c0b2f7Stbbdev 
18951c0b2f7Stbbdev     g.wait_for_all();
19051c0b2f7Stbbdev 
19151c0b2f7Stbbdev     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
19251c0b2f7Stbbdev }
19351c0b2f7Stbbdev 
19451c0b2f7Stbbdev //! Test node costructors
19551c0b2f7Stbbdev //! \brief \ref requirement
19651c0b2f7Stbbdev TEST_CASE("continue_node constructors"){
19751c0b2f7Stbbdev     test_ctors();
19851c0b2f7Stbbdev }
19951c0b2f7Stbbdev 
20051c0b2f7Stbbdev //! Test priorities work in single-threaded configuration
20151c0b2f7Stbbdev //! \brief \ref requirement
20251c0b2f7Stbbdev TEST_CASE("continue_node priority support"){
20351c0b2f7Stbbdev     test_priority();
20451c0b2f7Stbbdev }
20551c0b2f7Stbbdev 
20651c0b2f7Stbbdev //! Test body copying and copy_body logic
20751c0b2f7Stbbdev //! \brief \ref interface
20851c0b2f7Stbbdev TEST_CASE("continue_node and body copying"){
20951c0b2f7Stbbdev     test_copies();
21051c0b2f7Stbbdev }
21151c0b2f7Stbbdev 
21251c0b2f7Stbbdev //! Test constructors
21351c0b2f7Stbbdev //! \brief \ref interface
21451c0b2f7Stbbdev TEST_CASE("continue_node constructors"){
21551c0b2f7Stbbdev     test_policy_ctors();
21651c0b2f7Stbbdev }
21751c0b2f7Stbbdev 
21851c0b2f7Stbbdev //! Test continue_node buffering
21951c0b2f7Stbbdev //! \brief \ref requirement
22051c0b2f7Stbbdev TEST_CASE("continue_node buffering"){
22151c0b2f7Stbbdev     test_buffering();
22251c0b2f7Stbbdev }
22351c0b2f7Stbbdev 
22451c0b2f7Stbbdev //! Test function_node broadcasting
22551c0b2f7Stbbdev //! \brief \ref requirement
22651c0b2f7Stbbdev TEST_CASE("continue_node broadcast"){
22751c0b2f7Stbbdev     test_forwarding();
22851c0b2f7Stbbdev }
22951c0b2f7Stbbdev 
23051c0b2f7Stbbdev //! Test deduction guides
23151c0b2f7Stbbdev //! \brief \ref interface \ref requirement
23251c0b2f7Stbbdev TEST_CASE("Deduction guides"){
23351c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
23451c0b2f7Stbbdev     test_deduction_guides();
23551c0b2f7Stbbdev #endif
23651c0b2f7Stbbdev }
23751c0b2f7Stbbdev 
23851c0b2f7Stbbdev //! Test inheritance relations
23951c0b2f7Stbbdev //! \brief \ref interface
24051c0b2f7Stbbdev TEST_CASE("continue_node superclasses"){
24151c0b2f7Stbbdev     test_inheritance<int>();
24251c0b2f7Stbbdev     test_inheritance<void*>();
24351c0b2f7Stbbdev }
24451c0b2f7Stbbdev 
24551c0b2f7Stbbdev //! Test body execution
24651c0b2f7Stbbdev //! \brief \ref interface \ref requirement
24751c0b2f7Stbbdev TEST_CASE("continue body") {
24851c0b2f7Stbbdev     test_cont_body();
24951c0b2f7Stbbdev }
250