151c0b2f7Stbbdev /*
2*b15aabb3Stbbdev     Copyright (c) 2020-2021 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
17*b15aabb3Stbbdev #if __INTEL_COMPILER && _MSC_VER
18*b15aabb3Stbbdev #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19*b15aabb3Stbbdev #endif
2051c0b2f7Stbbdev 
#include "common/test.h"

#include "common/utils.h"
#include "common/graph_utils.h"

#include "oneapi/tbb/flow_graph.h"
#include "oneapi/tbb/task_arena.h"

#include "conformance_flowgraph.h"

#include <atomic>
#include <type_traits>
3051c0b2f7Stbbdev 
3151c0b2f7Stbbdev //! \file conformance_input_node.cpp
3251c0b2f7Stbbdev //! \brief Test for [flow_graph.input_node] specification
3351c0b2f7Stbbdev 
3451c0b2f7Stbbdev /*
3551c0b2f7Stbbdev TODO: implement missing conformance tests for input_node:
  - [ ] Check that the `copy_body' function copies the altered body (e.g. after its successful invocation).
  - [ ] Check that in `test_forwarding' the value passed is the actual one received.
  - [ ] Improve the CTAD test to assert the resulting node type.
  - [ ] Add an explicit test for the copy constructor of the node.
  - [ ] Check that the `Output' type is indeed copy-constructed and copy-assigned while working with the node.
  - [ ] Check that the node cannot have predecessors (would ADL be of any help here?).
  - [ ] Check that the node is serial and its body is never invoked concurrently.
  - [ ] `try_get()' call testing: the body is called only when the internal buffer is empty.
4451c0b2f7Stbbdev */
4551c0b2f7Stbbdev 
4651c0b2f7Stbbdev std::atomic<size_t> global_execute_count;
4751c0b2f7Stbbdev 
4851c0b2f7Stbbdev template<typename OutputType>
4951c0b2f7Stbbdev struct input_functor {
5051c0b2f7Stbbdev     const size_t n;
5151c0b2f7Stbbdev 
5251c0b2f7Stbbdev     input_functor( ) : n(10) { }
5351c0b2f7Stbbdev     input_functor( const input_functor &f ) : n(f.n) {  }
5451c0b2f7Stbbdev     void operator=(const input_functor &f) { n = f.n; }
5551c0b2f7Stbbdev 
5649e08aacStbbdev     OutputType operator()( oneapi::tbb::flow_control & fc ) {
5751c0b2f7Stbbdev        ++global_execute_count;
5851c0b2f7Stbbdev        if(global_execute_count > n){
5951c0b2f7Stbbdev            fc.stop();
6051c0b2f7Stbbdev            return OutputType();
6151c0b2f7Stbbdev        }
6251c0b2f7Stbbdev        return OutputType(global_execute_count.load());
6351c0b2f7Stbbdev     }
6451c0b2f7Stbbdev 
6551c0b2f7Stbbdev };
6651c0b2f7Stbbdev 
6751c0b2f7Stbbdev template<typename O>
6851c0b2f7Stbbdev struct CopyCounterBody{
6951c0b2f7Stbbdev     size_t copy_count;
7051c0b2f7Stbbdev 
7151c0b2f7Stbbdev     CopyCounterBody():
7251c0b2f7Stbbdev         copy_count(0) {}
7351c0b2f7Stbbdev 
7451c0b2f7Stbbdev     CopyCounterBody(const CopyCounterBody<O>& other):
7551c0b2f7Stbbdev         copy_count(other.copy_count + 1) {}
7651c0b2f7Stbbdev 
7751c0b2f7Stbbdev     CopyCounterBody& operator=(const CopyCounterBody<O>& other) {
7851c0b2f7Stbbdev         copy_count = other.copy_count + 1; return *this;
7951c0b2f7Stbbdev     }
8051c0b2f7Stbbdev 
8149e08aacStbbdev     O operator()(oneapi::tbb::flow_control & fc){
8251c0b2f7Stbbdev         fc.stop();
8351c0b2f7Stbbdev         return O();
8451c0b2f7Stbbdev     }
8551c0b2f7Stbbdev };
8651c0b2f7Stbbdev 
8751c0b2f7Stbbdev 
8851c0b2f7Stbbdev void test_input_body(){
8949e08aacStbbdev     oneapi::tbb::flow::graph g;
9051c0b2f7Stbbdev     input_functor<int> fun;
9151c0b2f7Stbbdev 
9251c0b2f7Stbbdev     global_execute_count = 0;
9349e08aacStbbdev     oneapi::tbb::flow::input_node<int> node1(g, fun);
9451c0b2f7Stbbdev     test_push_receiver<int> node2(g);
9551c0b2f7Stbbdev 
9649e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
9751c0b2f7Stbbdev 
9851c0b2f7Stbbdev     node1.activate();
9951c0b2f7Stbbdev     g.wait_for_all();
10051c0b2f7Stbbdev 
10151c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages");
10251c0b2f7Stbbdev     CHECK_MESSAGE( (global_execute_count == 10 + 1), "Body of the node needs to be executed N + 1 times");
10351c0b2f7Stbbdev }
10451c0b2f7Stbbdev 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Checks that class template argument deduction from (graph, body) yields
//! input_node<Output>, where Output is the return type of the body's operator().
void test_deduction_guides(){
    oneapi::tbb::flow::graph g;
    input_functor<int> fun;
    oneapi::tbb::flow::input_node node1(g, fun);
    static_assert(std::is_same<decltype(node1), oneapi::tbb::flow::input_node<int>>::value,
                  "input_node deduced from a body returning int should be input_node<int>");
}
#endif
11251c0b2f7Stbbdev 
11351c0b2f7Stbbdev void test_buffering(){
11449e08aacStbbdev     oneapi::tbb::flow::graph g;
11551c0b2f7Stbbdev     input_functor<int> fun;
11651c0b2f7Stbbdev     global_execute_count = 0;
11751c0b2f7Stbbdev 
11849e08aacStbbdev     oneapi::tbb::flow::input_node<int> source(g, fun);
11949e08aacStbbdev     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
12051c0b2f7Stbbdev 
12149e08aacStbbdev     oneapi::tbb::flow::make_edge(source, rejecter);
12251c0b2f7Stbbdev     source.activate();
12351c0b2f7Stbbdev     g.wait_for_all();
12451c0b2f7Stbbdev 
12551c0b2f7Stbbdev     int tmp = -1;
12651c0b2f7Stbbdev     CHECK_MESSAGE( (source.try_get(tmp) == true), "try_get after rejection should succeed");
12751c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == 1), "try_get should return correct value");
12851c0b2f7Stbbdev }
12951c0b2f7Stbbdev 
13051c0b2f7Stbbdev void test_forwarding(){
13149e08aacStbbdev     oneapi::tbb::flow::graph g;
13251c0b2f7Stbbdev     input_functor<int> fun;
13351c0b2f7Stbbdev 
13451c0b2f7Stbbdev     global_execute_count = 0;
13549e08aacStbbdev     oneapi::tbb::flow::input_node<int> node1(g, fun);
13651c0b2f7Stbbdev     test_push_receiver<int> node2(g);
13751c0b2f7Stbbdev     test_push_receiver<int> node3(g);
13851c0b2f7Stbbdev 
13949e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node2);
14049e08aacStbbdev     oneapi::tbb::flow::make_edge(node1, node3);
14151c0b2f7Stbbdev 
14251c0b2f7Stbbdev     node1.activate();
14351c0b2f7Stbbdev     g.wait_for_all();
14451c0b2f7Stbbdev 
14551c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages");
14651c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 10), "Descendant of the node needs to be receive N messages");
14751c0b2f7Stbbdev }
14851c0b2f7Stbbdev 
14951c0b2f7Stbbdev template<typename O>
15051c0b2f7Stbbdev void test_inheritance(){
15149e08aacStbbdev     using namespace oneapi::tbb::flow;
15251c0b2f7Stbbdev 
15351c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, input_node<O>>::value), "input_node should be derived from graph_node");
15451c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<sender<O>, input_node<O>>::value), "input_node should be derived from sender<Output>");
15551c0b2f7Stbbdev }
15651c0b2f7Stbbdev 
15751c0b2f7Stbbdev void test_copies(){
15849e08aacStbbdev     using namespace oneapi::tbb::flow;
15951c0b2f7Stbbdev 
16051c0b2f7Stbbdev     CopyCounterBody<int> b;
16151c0b2f7Stbbdev 
16251c0b2f7Stbbdev     graph g;
16351c0b2f7Stbbdev     input_node<int> fn(g, b);
16451c0b2f7Stbbdev 
16551c0b2f7Stbbdev     CopyCounterBody<int> b2 = copy_body<CopyCounterBody<int>, input_node<int>>(fn);
16651c0b2f7Stbbdev 
16751c0b2f7Stbbdev     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
16851c0b2f7Stbbdev }
16951c0b2f7Stbbdev 
17051c0b2f7Stbbdev //! Test body copying and copy_body logic
17151c0b2f7Stbbdev //! \brief \ref interface
17251c0b2f7Stbbdev TEST_CASE("input_node and body copying"){
17351c0b2f7Stbbdev     test_copies();
17451c0b2f7Stbbdev }
17551c0b2f7Stbbdev 
17651c0b2f7Stbbdev //! Test inheritance relations
17751c0b2f7Stbbdev //! \brief \ref interface
17851c0b2f7Stbbdev TEST_CASE("input_node superclasses"){
17951c0b2f7Stbbdev     test_inheritance<int>();
18051c0b2f7Stbbdev     test_inheritance<void*>();
18151c0b2f7Stbbdev }
18251c0b2f7Stbbdev 
18351c0b2f7Stbbdev //! Test input_node forwarding
18451c0b2f7Stbbdev //! \brief \ref requirement
18551c0b2f7Stbbdev TEST_CASE("input_node forwarding"){
18651c0b2f7Stbbdev     test_forwarding();
18751c0b2f7Stbbdev }
18851c0b2f7Stbbdev 
18951c0b2f7Stbbdev //! Test input_node buffering
19051c0b2f7Stbbdev //! \brief \ref requirement
19151c0b2f7Stbbdev TEST_CASE("input_node buffering"){
19251c0b2f7Stbbdev     test_buffering();
19351c0b2f7Stbbdev }
19451c0b2f7Stbbdev 
19551c0b2f7Stbbdev //! Test calling input_node body
19651c0b2f7Stbbdev //! \brief \ref interface \ref requirement
19751c0b2f7Stbbdev TEST_CASE("input_node body") {
19851c0b2f7Stbbdev     test_input_body();
19951c0b2f7Stbbdev }
20051c0b2f7Stbbdev 
20151c0b2f7Stbbdev //! Test deduction guides
20251c0b2f7Stbbdev //! \brief \ref interface \ref requirement
20351c0b2f7Stbbdev TEST_CASE("Deduction guides"){
20451c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
20551c0b2f7Stbbdev     test_deduction_guides();
20651c0b2f7Stbbdev #endif
20751c0b2f7Stbbdev }
208