1 /*
2     Copyright (c) 2020-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #if __INTEL_COMPILER && _MSC_VER
18 #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19 #endif
20 
#include "common/test.h"

#include "common/utils.h"
#include "common/graph_utils.h"

#include "oneapi/tbb/flow_graph.h"
#include "oneapi/tbb/task_arena.h"

#include "conformance_flowgraph.h"

#include <type_traits>
30 
31 //! \file conformance_input_node.cpp
32 //! \brief Test for [flow_graph.input_node] specification
33 
34 /*
35 TODO: implement missing conformance tests for input_node:
36   - [ ] The `copy_body' function copies altered body (e.g. after its successful invocation).
37   - [ ] Check that in `test_forwarding' the value passed is the actual one received.
38   - [ ] Improve CTAD test to assert result node type.
39   - [ ] Explicit test for copy constructor of the node.
  - [ ] Check that the `Output' type is indeed copy-constructed and copy-assigned while working with the node.
  - [ ] Check that the node cannot have predecessors (will ADL be of any help here?).
42   - [ ] Check the node is serial and its body never invoked concurrently.
43   - [ ] `try_get()' call testing: a call to body is made only when the internal buffer is empty.
44 */
45 
//! Shared invocation counter incremented by every input_functor call; the
//! tests reset it to zero before activating a node that uses the functor.
std::atomic<size_t> global_execute_count{0};
47 
48 template<typename OutputType>
49 struct input_functor {
50     const size_t n;
51 
52     input_functor( ) : n(10) { }
53     input_functor( const input_functor &f ) : n(f.n) {  }
54     void operator=(const input_functor &f) { n = f.n; }
55 
56     OutputType operator()( oneapi::tbb::flow_control & fc ) {
57        ++global_execute_count;
58        if(global_execute_count > n){
59            fc.stop();
60            return OutputType();
61        }
62        return OutputType(global_execute_count.load());
63     }
64 
65 };
66 
67 template<typename O>
68 struct CopyCounterBody{
69     size_t copy_count;
70 
71     CopyCounterBody():
72         copy_count(0) {}
73 
74     CopyCounterBody(const CopyCounterBody<O>& other):
75         copy_count(other.copy_count + 1) {}
76 
77     CopyCounterBody& operator=(const CopyCounterBody<O>& other) {
78         copy_count = other.copy_count + 1; return *this;
79     }
80 
81     O operator()(oneapi::tbb::flow_control & fc){
82         fc.stop();
83         return O();
84     }
85 };
86 
87 
88 void test_input_body(){
89     oneapi::tbb::flow::graph g;
90     input_functor<int> fun;
91 
92     global_execute_count = 0;
93     oneapi::tbb::flow::input_node<int> node1(g, fun);
94     test_push_receiver<int> node2(g);
95 
96     oneapi::tbb::flow::make_edge(node1, node2);
97 
98     node1.activate();
99     g.wait_for_all();
100 
101     CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages");
102     CHECK_MESSAGE( (global_execute_count == 10 + 1), "Body of the node needs to be executed N + 1 times");
103 }
104 
105 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
106 void test_deduction_guides(){
107     oneapi::tbb::flow::graph g;
108     input_functor<int> fun;
109     oneapi::tbb::flow::input_node node1(g, fun);
110 }
111 #endif
112 
113 void test_buffering(){
114     oneapi::tbb::flow::graph g;
115     input_functor<int> fun;
116     global_execute_count = 0;
117 
118     oneapi::tbb::flow::input_node<int> source(g, fun);
119     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
120 
121     oneapi::tbb::flow::make_edge(source, rejecter);
122     source.activate();
123     g.wait_for_all();
124 
125     int tmp = -1;
126     CHECK_MESSAGE( (source.try_get(tmp) == true), "try_get after rejection should succeed");
127     CHECK_MESSAGE( (tmp == 1), "try_get should return correct value");
128 }
129 
130 void test_forwarding(){
131     oneapi::tbb::flow::graph g;
132     input_functor<int> fun;
133 
134     global_execute_count = 0;
135     oneapi::tbb::flow::input_node<int> node1(g, fun);
136     test_push_receiver<int> node2(g);
137     test_push_receiver<int> node3(g);
138 
139     oneapi::tbb::flow::make_edge(node1, node2);
140     oneapi::tbb::flow::make_edge(node1, node3);
141 
142     node1.activate();
143     g.wait_for_all();
144 
145     CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages");
146     CHECK_MESSAGE( (get_count(node3) == 10), "Descendant of the node needs to be receive N messages");
147 }
148 
149 template<typename O>
150 void test_inheritance(){
151     using namespace oneapi::tbb::flow;
152 
153     CHECK_MESSAGE( (std::is_base_of<graph_node, input_node<O>>::value), "input_node should be derived from graph_node");
154     CHECK_MESSAGE( (std::is_base_of<sender<O>, input_node<O>>::value), "input_node should be derived from sender<Output>");
155 }
156 
157 void test_copies(){
158     using namespace oneapi::tbb::flow;
159 
160     CopyCounterBody<int> b;
161 
162     graph g;
163     input_node<int> fn(g, b);
164 
165     CopyCounterBody<int> b2 = copy_body<CopyCounterBody<int>, input_node<int>>(fn);
166 
167     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
168 }
169 
//! Test body copying and copy_body logic
//! \brief \ref interface
TEST_CASE("input_node and body copying"){
    test_copies();
}

//! Test inheritance relations
//! \brief \ref interface
TEST_CASE("input_node superclasses"){
    // Instantiated for a fundamental output type and a pointer output type.
    test_inheritance<int>();
    test_inheritance<void*>();
}

//! Test input_node forwarding
//! \brief \ref requirement
TEST_CASE("input_node forwarding"){
    test_forwarding();
}

//! Test input_node buffering
//! \brief \ref requirement
TEST_CASE("input_node buffering"){
    test_buffering();
}

//! Test calling input_node body
//! \brief \ref interface \ref requirement
TEST_CASE("input_node body") {
    test_input_body();
}

//! Test deduction guides
//! \brief \ref interface \ref requirement
TEST_CASE("Deduction guides"){
    // Without C++17 deduction-guide support this test case has an empty body.
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
    test_deduction_guides();
#endif
}
208