151c0b2f7Stbbdev /*
251c0b2f7Stbbdev     Copyright (c) 2020 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
1751c0b2f7Stbbdev 
#include "common/test.h"

#include "common/utils.h"
#include "common/graph_utils.h"

#include "oneapi/tbb/flow_graph.h"
#include "oneapi/tbb/task_arena.h"

#include "conformance_flowgraph.h"

#include <atomic>
#include <type_traits>
2751c0b2f7Stbbdev 
2851c0b2f7Stbbdev //! \file conformance_input_node.cpp
2951c0b2f7Stbbdev //! \brief Test for [flow_graph.input_node] specification
3051c0b2f7Stbbdev 
3151c0b2f7Stbbdev /*
3251c0b2f7Stbbdev TODO: implement missing conformance tests for input_node:
  - [ ] The `copy_body' function copies the altered body (e.g. after its successful invocation).
3451c0b2f7Stbbdev   - [ ] Check that in `test_forwarding' the value passed is the actual one received.
3551c0b2f7Stbbdev   - [ ] Improve CTAD test to assert result node type.
3651c0b2f7Stbbdev   - [ ] Explicit test for copy constructor of the node.
  - [ ] Check that the `Output' type is indeed copy-constructed and copy-assigned while working with the node.
3851c0b2f7Stbbdev   - [ ] Check node cannot have predecessors (Will ADL be of any help here?)
3951c0b2f7Stbbdev   - [ ] Check the node is serial and its body never invoked concurrently.
4051c0b2f7Stbbdev   - [ ] `try_get()' call testing: a call to body is made only when the internal buffer is empty.
4151c0b2f7Stbbdev */
4251c0b2f7Stbbdev 
4351c0b2f7Stbbdev std::atomic<size_t> global_execute_count;
4451c0b2f7Stbbdev 
4551c0b2f7Stbbdev template<typename OutputType>
4651c0b2f7Stbbdev struct input_functor {
4751c0b2f7Stbbdev     const size_t n;
4851c0b2f7Stbbdev 
4951c0b2f7Stbbdev     input_functor( ) : n(10) { }
5051c0b2f7Stbbdev     input_functor( const input_functor &f ) : n(f.n) {  }
5151c0b2f7Stbbdev     void operator=(const input_functor &f) { n = f.n; }
5251c0b2f7Stbbdev 
53*49e08aacStbbdev     OutputType operator()( oneapi::tbb::flow_control & fc ) {
5451c0b2f7Stbbdev        ++global_execute_count;
5551c0b2f7Stbbdev        if(global_execute_count > n){
5651c0b2f7Stbbdev            fc.stop();
5751c0b2f7Stbbdev            return OutputType();
5851c0b2f7Stbbdev        }
5951c0b2f7Stbbdev        return OutputType(global_execute_count.load());
6051c0b2f7Stbbdev     }
6151c0b2f7Stbbdev 
6251c0b2f7Stbbdev };
6351c0b2f7Stbbdev 
6451c0b2f7Stbbdev template<typename O>
6551c0b2f7Stbbdev struct CopyCounterBody{
6651c0b2f7Stbbdev     size_t copy_count;
6751c0b2f7Stbbdev 
6851c0b2f7Stbbdev     CopyCounterBody():
6951c0b2f7Stbbdev         copy_count(0) {}
7051c0b2f7Stbbdev 
7151c0b2f7Stbbdev     CopyCounterBody(const CopyCounterBody<O>& other):
7251c0b2f7Stbbdev         copy_count(other.copy_count + 1) {}
7351c0b2f7Stbbdev 
7451c0b2f7Stbbdev     CopyCounterBody& operator=(const CopyCounterBody<O>& other) {
7551c0b2f7Stbbdev         copy_count = other.copy_count + 1; return *this;
7651c0b2f7Stbbdev     }
7751c0b2f7Stbbdev 
78*49e08aacStbbdev     O operator()(oneapi::tbb::flow_control & fc){
7951c0b2f7Stbbdev         fc.stop();
8051c0b2f7Stbbdev         return O();
8151c0b2f7Stbbdev     }
8251c0b2f7Stbbdev };
8351c0b2f7Stbbdev 
8451c0b2f7Stbbdev 
8551c0b2f7Stbbdev void test_input_body(){
86*49e08aacStbbdev     oneapi::tbb::flow::graph g;
8751c0b2f7Stbbdev     input_functor<int> fun;
8851c0b2f7Stbbdev 
8951c0b2f7Stbbdev     global_execute_count = 0;
90*49e08aacStbbdev     oneapi::tbb::flow::input_node<int> node1(g, fun);
9151c0b2f7Stbbdev     test_push_receiver<int> node2(g);
9251c0b2f7Stbbdev 
93*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
9451c0b2f7Stbbdev 
9551c0b2f7Stbbdev     node1.activate();
9651c0b2f7Stbbdev     g.wait_for_all();
9751c0b2f7Stbbdev 
9851c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages");
9951c0b2f7Stbbdev     CHECK_MESSAGE( (global_execute_count == 10 + 1), "Body of the node needs to be executed N + 1 times");
10051c0b2f7Stbbdev }
10151c0b2f7Stbbdev 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
// Checks class template argument deduction for input_node.
// Constructing it from a graph and a body whose operator() returns int must
// produce an input_node<int>. The check is made at compile time.
void test_deduction_guides(){
    oneapi::tbb::flow::graph g;
    input_functor<int> fun;
    oneapi::tbb::flow::input_node node1(g, fun);
    static_assert(std::is_same<decltype(node1), oneapi::tbb::flow::input_node<int>>::value,
                  "input_node<int> should be deduced from a body returning int");
}
#endif
10951c0b2f7Stbbdev 
11051c0b2f7Stbbdev void test_buffering(){
111*49e08aacStbbdev     oneapi::tbb::flow::graph g;
11251c0b2f7Stbbdev     input_functor<int> fun;
11351c0b2f7Stbbdev     global_execute_count = 0;
11451c0b2f7Stbbdev 
115*49e08aacStbbdev     oneapi::tbb::flow::input_node<int> source(g, fun);
116*49e08aacStbbdev     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
11751c0b2f7Stbbdev 
118*49e08aacStbbdev     oneapi::tbb::flow::make_edge(source, rejecter);
11951c0b2f7Stbbdev     source.activate();
12051c0b2f7Stbbdev     g.wait_for_all();
12151c0b2f7Stbbdev 
12251c0b2f7Stbbdev     int tmp = -1;
12351c0b2f7Stbbdev     CHECK_MESSAGE( (source.try_get(tmp) == true), "try_get after rejection should succeed");
12451c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == 1), "try_get should return correct value");
12551c0b2f7Stbbdev }
12651c0b2f7Stbbdev 
12751c0b2f7Stbbdev void test_forwarding(){
128*49e08aacStbbdev     oneapi::tbb::flow::graph g;
12951c0b2f7Stbbdev     input_functor<int> fun;
13051c0b2f7Stbbdev 
13151c0b2f7Stbbdev     global_execute_count = 0;
132*49e08aacStbbdev     oneapi::tbb::flow::input_node<int> node1(g, fun);
13351c0b2f7Stbbdev     test_push_receiver<int> node2(g);
13451c0b2f7Stbbdev     test_push_receiver<int> node3(g);
13551c0b2f7Stbbdev 
136*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
137*49e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node3);
13851c0b2f7Stbbdev 
13951c0b2f7Stbbdev     node1.activate();
14051c0b2f7Stbbdev     g.wait_for_all();
14151c0b2f7Stbbdev 
14251c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages");
14351c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 10), "Descendant of the node needs to be receive N messages");
14451c0b2f7Stbbdev }
14551c0b2f7Stbbdev 
14651c0b2f7Stbbdev template<typename O>
14751c0b2f7Stbbdev void test_inheritance(){
148*49e08aacStbbdev     using namespace oneapi::tbb::flow;
14951c0b2f7Stbbdev 
15051c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, input_node<O>>::value), "input_node should be derived from graph_node");
15151c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<sender<O>, input_node<O>>::value), "input_node should be derived from sender<Output>");
15251c0b2f7Stbbdev }
15351c0b2f7Stbbdev 
15451c0b2f7Stbbdev void test_copies(){
155*49e08aacStbbdev     using namespace oneapi::tbb::flow;
15651c0b2f7Stbbdev 
15751c0b2f7Stbbdev     CopyCounterBody<int> b;
15851c0b2f7Stbbdev 
15951c0b2f7Stbbdev     graph g;
16051c0b2f7Stbbdev     input_node<int> fn(g, b);
16151c0b2f7Stbbdev 
16251c0b2f7Stbbdev     CopyCounterBody<int> b2 = copy_body<CopyCounterBody<int>, input_node<int>>(fn);
16351c0b2f7Stbbdev 
16451c0b2f7Stbbdev     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
16551c0b2f7Stbbdev }
16651c0b2f7Stbbdev 
16751c0b2f7Stbbdev //! Test body copying and copy_body logic
16851c0b2f7Stbbdev //! \brief \ref interface
16951c0b2f7Stbbdev TEST_CASE("input_node and body copying"){
17051c0b2f7Stbbdev     test_copies();
17151c0b2f7Stbbdev }
17251c0b2f7Stbbdev 
17351c0b2f7Stbbdev //! Test inheritance relations
17451c0b2f7Stbbdev //! \brief \ref interface
17551c0b2f7Stbbdev TEST_CASE("input_node superclasses"){
17651c0b2f7Stbbdev     test_inheritance<int>();
17751c0b2f7Stbbdev     test_inheritance<void*>();
17851c0b2f7Stbbdev }
17951c0b2f7Stbbdev 
18051c0b2f7Stbbdev //! Test input_node forwarding
18151c0b2f7Stbbdev //! \brief \ref requirement
18251c0b2f7Stbbdev TEST_CASE("input_node forwarding"){
18351c0b2f7Stbbdev     test_forwarding();
18451c0b2f7Stbbdev }
18551c0b2f7Stbbdev 
18651c0b2f7Stbbdev //! Test input_node buffering
18751c0b2f7Stbbdev //! \brief \ref requirement
18851c0b2f7Stbbdev TEST_CASE("input_node buffering"){
18951c0b2f7Stbbdev     test_buffering();
19051c0b2f7Stbbdev }
19151c0b2f7Stbbdev 
19251c0b2f7Stbbdev //! Test calling input_node body
19351c0b2f7Stbbdev //! \brief \ref interface \ref requirement
19451c0b2f7Stbbdev TEST_CASE("input_node body") {
19551c0b2f7Stbbdev     test_input_body();
19651c0b2f7Stbbdev }
19751c0b2f7Stbbdev 
19851c0b2f7Stbbdev //! Test deduction guides
19951c0b2f7Stbbdev //! \brief \ref interface \ref requirement
20051c0b2f7Stbbdev TEST_CASE("Deduction guides"){
20151c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
20251c0b2f7Stbbdev     test_deduction_guides();
20351c0b2f7Stbbdev #endif
20451c0b2f7Stbbdev }
205