1 /*
2     Copyright (c) 2020 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 
#include "common/test.h"

#include "common/utils.h"
#include "common/graph_utils.h"

#include "oneapi/tbb/flow_graph.h"
#include "oneapi/tbb/task_arena.h"

#include "conformance_flowgraph.h"

#include <atomic>
#include <cstddef>
#include <type_traits>
27 
28 //! \file conformance_input_node.cpp
29 //! \brief Test for [flow_graph.input_node] specification
30 
31 /*
32 TODO: implement missing conformance tests for input_node:
33   - [ ] The `copy_body' function copies altered body (e.g. after its successful invocation).
34   - [ ] Check that in `test_forwarding' the value passed is the actual one received.
35   - [ ] Improve CTAD test to assert result node type.
36   - [ ] Explicit test for copy constructor of the node.
  - [ ] Check that the `Output' type is indeed copy-constructed and copy-assigned while working with the node.
38   - [ ] Check node cannot have predecessors (Will ADL be of any help here?)
39   - [ ] Check the node is serial and its body never invoked concurrently.
40   - [ ] `try_get()' call testing: a call to body is made only when the internal buffer is empty.
41 */
42 
// Counter shared by every input_functor instance: incremented on each body invocation
// and reset to zero by the tests in this file before they activate a node.
std::atomic<size_t> global_execute_count;
44 
45 template<typename OutputType>
46 struct input_functor {
47     const size_t n;
48 
49     input_functor( ) : n(10) { }
50     input_functor( const input_functor &f ) : n(f.n) {  }
51     void operator=(const input_functor &f) { n = f.n; }
52 
53     OutputType operator()( oneapi::tbb::flow_control & fc ) {
54        ++global_execute_count;
55        if(global_execute_count > n){
56            fc.stop();
57            return OutputType();
58        }
59        return OutputType(global_execute_count.load());
60     }
61 
62 };
63 
64 template<typename O>
65 struct CopyCounterBody{
66     size_t copy_count;
67 
68     CopyCounterBody():
69         copy_count(0) {}
70 
71     CopyCounterBody(const CopyCounterBody<O>& other):
72         copy_count(other.copy_count + 1) {}
73 
74     CopyCounterBody& operator=(const CopyCounterBody<O>& other) {
75         copy_count = other.copy_count + 1; return *this;
76     }
77 
78     O operator()(oneapi::tbb::flow_control & fc){
79         fc.stop();
80         return O();
81     }
82 };
83 
84 
85 void test_input_body(){
86     oneapi::tbb::flow::graph g;
87     input_functor<int> fun;
88 
89     global_execute_count = 0;
90     oneapi::tbb::flow::input_node<int> node1(g, fun);
91     test_push_receiver<int> node2(g);
92 
93     oneapi::tbb::flow::make_edge(node1, node2);
94 
95     node1.activate();
96     g.wait_for_all();
97 
98     CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages");
99     CHECK_MESSAGE( (global_execute_count == 10 + 1), "Body of the node needs to be executed N + 1 times");
100 }
101 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
// Checks class template argument deduction for input_node constructed from a graph and a body:
// the node's output type is expected to follow the return type of the body's operator().
void test_deduction_guides(){
    oneapi::tbb::flow::graph g;
    input_functor<int> fun;
    oneapi::tbb::flow::input_node node1(g, fun);

    // Successful construction only shows that deduction compiles; also pin down the deduced type.
    static_assert(std::is_same<decltype(node1), oneapi::tbb::flow::input_node<int>>::value,
                  "input_node deduced from a body returning int should be input_node<int>");
}
#endif
109 
110 void test_buffering(){
111     oneapi::tbb::flow::graph g;
112     input_functor<int> fun;
113     global_execute_count = 0;
114 
115     oneapi::tbb::flow::input_node<int> source(g, fun);
116     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
117 
118     oneapi::tbb::flow::make_edge(source, rejecter);
119     source.activate();
120     g.wait_for_all();
121 
122     int tmp = -1;
123     CHECK_MESSAGE( (source.try_get(tmp) == true), "try_get after rejection should succeed");
124     CHECK_MESSAGE( (tmp == 1), "try_get should return correct value");
125 }
126 
127 void test_forwarding(){
128     oneapi::tbb::flow::graph g;
129     input_functor<int> fun;
130 
131     global_execute_count = 0;
132     oneapi::tbb::flow::input_node<int> node1(g, fun);
133     test_push_receiver<int> node2(g);
134     test_push_receiver<int> node3(g);
135 
136     oneapi::tbb::flow::make_edge(node1, node2);
137     oneapi::tbb::flow::make_edge(node1, node3);
138 
139     node1.activate();
140     g.wait_for_all();
141 
142     CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages");
143     CHECK_MESSAGE( (get_count(node3) == 10), "Descendant of the node needs to be receive N messages");
144 }
145 
146 template<typename O>
147 void test_inheritance(){
148     using namespace oneapi::tbb::flow;
149 
150     CHECK_MESSAGE( (std::is_base_of<graph_node, input_node<O>>::value), "input_node should be derived from graph_node");
151     CHECK_MESSAGE( (std::is_base_of<sender<O>, input_node<O>>::value), "input_node should be derived from sender<Output>");
152 }
153 
154 void test_copies(){
155     using namespace oneapi::tbb::flow;
156 
157     CopyCounterBody<int> b;
158 
159     graph g;
160     input_node<int> fn(g, b);
161 
162     CopyCounterBody<int> b2 = copy_body<CopyCounterBody<int>, input_node<int>>(fn);
163 
164     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
165 }
166 
167 //! Test body copying and copy_body logic
168 //! \brief \ref interface
169 TEST_CASE("input_node and body copying"){
170     test_copies();
171 }
172 
173 //! Test inheritance relations
174 //! \brief \ref interface
175 TEST_CASE("input_node superclasses"){
176     test_inheritance<int>();
177     test_inheritance<void*>();
178 }
179 
180 //! Test input_node forwarding
181 //! \brief \ref requirement
182 TEST_CASE("input_node forwarding"){
183     test_forwarding();
184 }
185 
186 //! Test input_node buffering
187 //! \brief \ref requirement
188 TEST_CASE("input_node buffering"){
189     test_buffering();
190 }
191 
192 //! Test calling input_node body
193 //! \brief \ref interface \ref requirement
194 TEST_CASE("input_node body") {
195     test_input_body();
196 }
197 
198 //! Test deduction guides
199 //! \brief \ref interface \ref requirement
200 TEST_CASE("Deduction guides"){
201 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
202     test_deduction_guides();
203 #endif
204 }
205