1 /*
2     Copyright (c) 2020-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #if __INTEL_COMPILER && _MSC_VER
18 #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated
19 #endif
20 
21 #include "common/test.h"
22 
23 #include "common/utils.h"
24 #include "common/graph_utils.h"
25 
26 #include "oneapi/tbb/flow_graph.h"
27 #include "oneapi/tbb/task_arena.h"
28 #include "oneapi/tbb/global_control.h"
29 
30 #include "conformance_flowgraph.h"
31 
32 //! \file conformance_function_node.cpp
33 //! \brief Test for [flow_graph.function_node] specification
34 
35 /*
36 TODO: implement missing conformance tests for function_node:
  - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body> function_node(
    graph &g, size_t concurrency, Body body, Policy(), node_priority_t priority = no_priority )'
39   - [ ] Explicit test for copy constructor of the node.
40   - [ ] Rename test_broadcast to test_forwarding and check that the value passed is the actual one
41     received.
42   - [ ] Concurrency testing of the node: make a loop over possible concurrency levels. It is
43     important to test at least on five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism'
44     obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if `max allowed
45     parallelism' is > 2, use something in the middle of the [1, max_allowed_parallelism]
46     interval. Use `utils::ExactConcurrencyLevel' entity (extending it if necessary).
47   - [ ] make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads;
48     add check that `try_put()' returns `false'
49   - [ ] The copy constructor and copy assignment are called for the node's input and output types.
50   - [ ] The `copy_body' function copies altered body (e.g. after successful `try_put()' call).
51   - [ ] Extend CTAD test to check all node's constructors.
52 */
53 
// Shared counters updated by concurrency_functor: the number of bodies currently
// running, and the largest such number observed so far. Explicitly start at zero.
std::atomic<size_t> my_concurrency{0};
std::atomic<size_t> my_max_concurrency{0};
56 
57 template< typename OutputType >
58 struct concurrency_functor {
59     OutputType operator()( int argument ) {
60         ++my_concurrency;
61 
62         size_t old_value = my_max_concurrency;
63         while(my_max_concurrency < my_concurrency &&
64               !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency))
65             ;
66 
67         size_t ms = 1000;
68         std::chrono::milliseconds sleep_time( ms );
69         std::this_thread::sleep_for( sleep_time );
70 
71         --my_concurrency;
72        return argument;
73     }
74 
75 };
76 
77 void test_func_body(){
78     oneapi::tbb::flow::graph g;
79     inc_functor<int> fun;
80     fun.execute_count = 0;
81 
82     oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun);
83 
84     const size_t n = 10;
85     for(size_t i = 0; i < n; ++i) {
86         CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true");
87     }
88     g.wait_for_all();
89 
90     CHECK_MESSAGE( (fun.execute_count == n), "Body of the node needs to be executed N times");
91 }
92 
93 void test_priority(){
94     size_t concurrency_limit = 1;
95     oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit);
96 
97     oneapi::tbb::flow::graph g;
98 
99     first_functor<int>::first_id.store(-1);
100     first_functor<int> low_functor(1);
101     first_functor<int> high_functor(2);
102 
103     oneapi::tbb::flow::continue_node<int> source(g, [&](oneapi::tbb::flow::continue_msg){return 1;} );
104 
105     oneapi::tbb::flow::function_node<int, int> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1));
106     oneapi::tbb::flow::function_node<int, int> low(g, oneapi::tbb::flow::unlimited, low_functor);
107 
108     make_edge(source, low);
109     make_edge(source, high);
110 
111     source.try_put(oneapi::tbb::flow::continue_msg());
112     g.wait_for_all();
113 
114     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
115 }
116 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Checks that class template argument deduction yields function_node<int, int>
//! from a body taking `const int&` and returning `int`.
void test_deduction_guides(){
    using namespace oneapi::tbb::flow;
    graph g;

    auto constant_body = [](const int&) -> int { return 1; };
    function_node deduced(g, unlimited, constant_body);

    using expected_type = function_node<int, int>;
    CHECK_MESSAGE((std::is_same_v<decltype(deduced), expected_type>), "Function node type must be deducible from its body");
}
#endif
127 
128 void test_broadcast(){
129     oneapi::tbb::flow::graph g;
130     passthru_body fun;
131 
132     oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun);
133     test_push_receiver<int> node2(g);
134     test_push_receiver<int> node3(g);
135 
136     oneapi::tbb::flow::make_edge(node1, node2);
137     oneapi::tbb::flow::make_edge(node1, node3);
138 
139     node1.try_put(1);
140     g.wait_for_all();
141 
142     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
143     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
144 }
145 
146 template<typename Policy>
147 void test_buffering(){
148     oneapi::tbb::flow::graph g;
149     passthru_body fun;
150 
151     oneapi::tbb::flow::function_node<int, int, Policy> node(g, oneapi::tbb::flow::unlimited, fun);
152     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
153 
154     oneapi::tbb::flow::make_edge(node, rejecter);
155     node.try_put(1);
156 
157     int tmp = -1;
158     CHECK_MESSAGE( (node.try_get(tmp) == false), "try_get after rejection should not succeed");
159     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should not alter passed value");
160     g.wait_for_all();
161 }
162 
163 void test_node_concurrency(){
164     my_concurrency = 0;
165     my_max_concurrency = 0;
166 
167     oneapi::tbb::flow::graph g;
168     concurrency_functor<int> counter;
169     oneapi::tbb::flow::function_node <int, int> fnode(g, oneapi::tbb::flow::serial, counter);
170 
171     test_push_receiver<int> sink(g);
172 
173     make_edge(fnode, sink);
174 
175     for(int i = 0; i < 10; ++i){
176         fnode.try_put(i);
177     }
178 
179     g.wait_for_all();
180 
181     CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism is not expected");
182 }
183 
184 template<typename I, typename O>
185 void test_inheritance(){
186     using namespace oneapi::tbb::flow;
187 
188     CHECK_MESSAGE( (std::is_base_of<graph_node, function_node<I, O>>::value), "function_node should be derived from graph_node");
189     CHECK_MESSAGE( (std::is_base_of<receiver<I>, function_node<I, O>>::value), "function_node should be derived from receiver<Input>");
190     CHECK_MESSAGE( (std::is_base_of<sender<O>, function_node<I, O>>::value), "function_node should be derived from sender<Output>");
191 }
192 
193 void test_policy_ctors(){
194     using namespace oneapi::tbb::flow;
195     graph g;
196 
197     function_node<int, int, lightweight> lw_node(g, oneapi::tbb::flow::serial,
198                                                           [](int v) { return v;});
199     function_node<int, int, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial,
200                                                           [](int v) { return v;});
201     function_node<int, int, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial,
202                                                           [](int v) { return v;});
203 
204 }
205 
//! Identity body that remembers whether this copy has been invoked:
//! `stored` is -1 after construction and becomes 1 after the first call.
class stateful_functor{
public:
    int stored;

    stateful_functor() : stored{-1} {}

    int operator()(int value) {
        stored = 1;
        return value;
    }
};
212 
213 void test_ctors(){
214     using namespace oneapi::tbb::flow;
215     graph g;
216 
217     function_node<int, int> fn(g, unlimited, stateful_functor());
218     fn.try_put(0);
219     g.wait_for_all();
220 
221     stateful_functor b1 = copy_body<stateful_functor, function_node<int, int>>(fn);
222     CHECK_MESSAGE( (b1.stored == 1), "First node should update");
223 
224     function_node<int, int> fn2(fn);
225     stateful_functor b2 = copy_body<stateful_functor, function_node<int, int>>(fn2);
226     CHECK_MESSAGE( (b2.stored == -1), "Copied node should not update");
227 }
228 
//! Identity body that counts how many copy operations produced it: each copy
//! construction or copy assignment yields the source's count plus one.
template<typename I, typename O>
struct CopyCounterBody{
    size_t copy_count = 0;

    CopyCounterBody() = default;

    CopyCounterBody(const CopyCounterBody& other)
        : copy_count(other.copy_count + 1) {}

    CopyCounterBody& operator=(const CopyCounterBody& other) {
        copy_count = other.copy_count + 1;
        return *this;
    }

    O operator()(I in) { return in; }
};
246 
247 void test_copies(){
248     using namespace oneapi::tbb::flow;
249 
250     CopyCounterBody<int, int> b;
251 
252     graph g;
253     function_node<int, int> fn(g, unlimited, b);
254 
255     CopyCounterBody<int, int> b2 = copy_body<CopyCounterBody<int, int>, function_node<int, int>>(fn);
256 
257     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
258 }
259 
260 void test_rejecting(){
261     oneapi::tbb::flow::graph g;
262     oneapi::tbb::flow::function_node <int, int, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial,
263                                                                     [&](int v){
264                                                                         size_t ms = 50;
265                                                                         std::chrono::milliseconds sleep_time( ms );
266                                                                         std::this_thread::sleep_for( sleep_time );
267                                                                         return v;
268                                                                     });
269 
270     test_push_receiver<int> sink(g);
271 
272     make_edge(fnode, sink);
273 
274     for(int i = 0; i < 10; ++i){
275         fnode.try_put(i);
276     }
277 
278     g.wait_for_all();
279     CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed");
280 }
281 
282 //! Test function_node with rejecting policy
283 //! \brief \ref interface
284 TEST_CASE("function_node with rejecting policy"){
285     test_rejecting();
286 }
287 
288 //! Test body copying and copy_body logic
289 //! \brief \ref interface
290 TEST_CASE("function_node and body copying"){
291     test_copies();
292 }
293 
294 //! Test constructors
295 //! \brief \ref interface
296 TEST_CASE("function_node constructors"){
297     test_policy_ctors();
298 }
299 
300 //! Test inheritance relations
301 //! \brief \ref interface
302 TEST_CASE("function_node superclasses"){
303     test_inheritance<int, int>();
304     test_inheritance<void*, float>();
305 }
306 
307 //! Test function_node buffering
308 //! \brief \ref requirement
309 TEST_CASE("function_node buffering"){
310     test_buffering<oneapi::tbb::flow::rejecting>();
311     test_buffering<oneapi::tbb::flow::queueing>();
312 }
313 
314 //! Test function_node broadcasting
315 //! \brief \ref requirement
316 TEST_CASE("function_node broadcast"){
317     test_broadcast();
318 }
319 
320 //! Test deduction guides
321 //! \brief \ref interface \ref requirement
322 TEST_CASE("Deduction guides"){
323 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
324     test_deduction_guides();
325 #endif
326 }
327 
328 //! Test priorities work in single-threaded configuration
329 //! \brief \ref requirement
330 TEST_CASE("function_node priority support"){
331     test_priority();
332 }
333 
334 //! Test that measured concurrency respects set limits
335 //! \brief \ref requirement
336 TEST_CASE("concurrency follows set limits"){
337     test_node_concurrency();
338 }
339 
340 //! Test calling function body
341 //! \brief \ref interface \ref requirement
342 TEST_CASE("Test function_node body") {
343     test_func_body();
344 }
345