1 /*
2     Copyright (c) 2020 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #include "common/test.h"
18 
19 #include "common/utils.h"
20 #include "common/graph_utils.h"
21 
22 #include "oneapi/tbb/flow_graph.h"
23 #include "oneapi/tbb/task_arena.h"
24 #include "oneapi/tbb/global_control.h"
25 
26 #include "conformance_flowgraph.h"
27 
28 //! \file conformance_function_node.cpp
29 //! \brief Test for [flow_graph.function_node] specification
30 
31 /*
32 TODO: implement missing conformance tests for function_node:
33   - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body> function_node(
34     graph &g, size_t concurrency, Body body, Policy(), node_priority_t, priority = no_priority )'
35   - [ ] Explicit test for copy constructor of the node.
36   - [ ] Rename test_broadcast to test_forwarding and check that the value passed is the actual one
37     received.
  - [ ] Concurrency testing of the node: loop over possible concurrency levels. It is
    important to test at least five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism'
    obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if
    `max_allowed_parallelism' is > 2, a value in the middle of the [1, max_allowed_parallelism]
    interval. Use the `utils::ExactConcurrencyLevel' entity (extending it if necessary).
  - [ ] Make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads;
    add a check that `try_put()' returns `false'.
  - [ ] Check that the copy constructor and copy assignment are called for the node's input and output types.
  - [ ] Check that the `copy_body' function copies the altered body (e.g. after a successful `try_put()' call).
47   - [ ] Extend CTAD test to check all node's constructors.
48 */
49 
// Number of concurrency_functor bodies that are executing right now.
std::atomic<size_t> my_concurrency;
// Highest value of `my_concurrency' observed by any concurrency_functor invocation.
std::atomic<size_t> my_max_concurrency;

//! Function body that measures how many of its invocations overlap in time.
/** Each call increments `my_concurrency', raises `my_max_concurrency' to the level this call
    observed, sleeps for `sleep_time' so that concurrently scheduled calls overlap, then
    decrements `my_concurrency' and returns its argument converted to OutputType. */
template< typename OutputType >
struct concurrency_functor {
    //! How long each invocation stays "in flight"; defaults to one second.
    std::chrono::milliseconds sleep_time{1000};

    OutputType operator()( int argument ) {
        // Use the value produced by *this* increment. Re-reading the atomic later could see a
        // different (possibly smaller) value written by other threads in the meantime.
        const size_t current = ++my_concurrency;

        size_t observed_max = my_max_concurrency.load();
        // compare_exchange_weak refreshes observed_max on failure, so the loop ends as soon as
        // the stored maximum is >= current. The maximum therefore never decreases.
        while( observed_max < current &&
               !my_max_concurrency.compare_exchange_weak(observed_max, current) )
            ;

        std::this_thread::sleep_for( sleep_time );

        --my_concurrency;
        return argument;
    }
};
72 
73 void test_func_body(){
74     oneapi::tbb::flow::graph g;
75     inc_functor<int> fun;
76     fun.execute_count = 0;
77 
78     oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun);
79 
80     const size_t n = 10;
81     for(size_t i = 0; i < n; ++i) {
82         CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true");
83     }
84     g.wait_for_all();
85 
86     CHECK_MESSAGE( (fun.execute_count == n), "Body of the node needs to be executed N times");
87 }
88 
89 void test_priority(){
90     size_t concurrency_limit = 1;
91     oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit);
92 
93     oneapi::tbb::flow::graph g;
94 
95     first_functor<int>::first_id.store(-1);
96     first_functor<int> low_functor(1);
97     first_functor<int> high_functor(2);
98 
99     oneapi::tbb::flow::continue_node<int> source(g, [&](oneapi::tbb::flow::continue_msg){return 1;} );
100 
101     oneapi::tbb::flow::function_node<int, int> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1));
102     oneapi::tbb::flow::function_node<int, int> low(g, oneapi::tbb::flow::unlimited, low_functor);
103 
104     make_edge(source, low);
105     make_edge(source, high);
106 
107     source.try_put(oneapi::tbb::flow::continue_msg());
108     g.wait_for_all();
109 
110     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
111 }
112 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Checks that class template argument deduction produces function_node<int, int>
//! from graph, concurrency and a body of signature int(const int&).
void test_deduction_guides(){
    using namespace oneapi::tbb::flow;
    graph g;

    auto body = [](const int&) -> int { return 1; };
    function_node deduced(g, unlimited, body);

    using expected_type = function_node<int, int>;
    constexpr bool deduced_as_expected = std::is_same_v<decltype(deduced), expected_type>;
    CHECK_MESSAGE(deduced_as_expected, "Function node type must be deducible from its body");
}
#endif
123 
124 void test_broadcast(){
125     oneapi::tbb::flow::graph g;
126     passthru_body fun;
127 
128     oneapi::tbb::flow::function_node<int, int> node1(g, oneapi::tbb::flow::unlimited, fun);
129     test_push_receiver<int> node2(g);
130     test_push_receiver<int> node3(g);
131 
132     oneapi::tbb::flow::make_edge(node1, node2);
133     oneapi::tbb::flow::make_edge(node1, node3);
134 
135     node1.try_put(1);
136     g.wait_for_all();
137 
138     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
139     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
140 }
141 
142 template<typename Policy>
143 void test_buffering(){
144     oneapi::tbb::flow::graph g;
145     passthru_body fun;
146 
147     oneapi::tbb::flow::function_node<int, int, Policy> node(g, oneapi::tbb::flow::unlimited, fun);
148     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
149 
150     oneapi::tbb::flow::make_edge(node, rejecter);
151     node.try_put(1);
152 
153     int tmp = -1;
154     CHECK_MESSAGE( (node.try_get(tmp) == false), "try_get after rejection should not succeed");
155     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should not alter passed value");
156     g.wait_for_all();
157 }
158 
159 void test_node_concurrency(){
160     my_concurrency = 0;
161     my_max_concurrency = 0;
162 
163     oneapi::tbb::flow::graph g;
164     concurrency_functor<int> counter;
165     oneapi::tbb::flow::function_node <int, int> fnode(g, oneapi::tbb::flow::serial, counter);
166 
167     test_push_receiver<int> sink(g);
168 
169     make_edge(fnode, sink);
170 
171     for(int i = 0; i < 10; ++i){
172         fnode.try_put(i);
173     }
174 
175     g.wait_for_all();
176 
177     CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism is not expected");
178 }
179 
180 template<typename I, typename O>
181 void test_inheritance(){
182     using namespace oneapi::tbb::flow;
183 
184     CHECK_MESSAGE( (std::is_base_of<graph_node, function_node<I, O>>::value), "function_node should be derived from graph_node");
185     CHECK_MESSAGE( (std::is_base_of<receiver<I>, function_node<I, O>>::value), "function_node should be derived from receiver<Input>");
186     CHECK_MESSAGE( (std::is_base_of<sender<O>, function_node<I, O>>::value), "function_node should be derived from sender<Output>");
187 }
188 
189 void test_policy_ctors(){
190     using namespace oneapi::tbb::flow;
191     graph g;
192 
193     function_node<int, int, lightweight> lw_node(g, oneapi::tbb::flow::serial,
194                                                           [](int v) { return v;});
195     function_node<int, int, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial,
196                                                           [](int v) { return v;});
197     function_node<int, int, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial,
198                                                           [](int v) { return v;});
199 
200 }
201 
//! Body that records whether it has ever been invoked.
/** `stored' is -1 on a freshly constructed object and is set to 1 by every call;
    the call returns its argument unchanged. */
class stateful_functor{
public:
    int stored = -1;

    stateful_functor() {}

    int operator()(int value) {
        stored = 1;  // remember that the body was executed
        return value;
    }
};
208 
209 void test_ctors(){
210     using namespace oneapi::tbb::flow;
211     graph g;
212 
213     function_node<int, int> fn(g, unlimited, stateful_functor());
214     fn.try_put(0);
215     g.wait_for_all();
216 
217     stateful_functor b1 = copy_body<stateful_functor, function_node<int, int>>(fn);
218     CHECK_MESSAGE( (b1.stored == 1), "First node should update");
219 
220     function_node<int, int> fn2(fn);
221     stateful_functor b2 = copy_body<stateful_functor, function_node<int, int>>(fn2);
222     CHECK_MESSAGE( (b2.stored == -1), "Copied node should not update");
223 }
224 
//! Body that counts the chain of copies leading to the current object.
/** Default construction starts at 0. Copy construction and copy assignment set the count to
    the source's count plus one. No move operations are declared, so moves are counted as copies.
    Invocation returns its input converted to O. */
template<typename I, typename O>
struct CopyCounterBody{
    size_t copy_count = 0;

    CopyCounterBody() {}

    CopyCounterBody(const CopyCounterBody<I, O>& other)
        : copy_count(other.copy_count + 1) {}

    CopyCounterBody& operator=(const CopyCounterBody<I, O>& other) {
        copy_count = other.copy_count;
        ++copy_count;
        return *this;
    }

    O operator()(I in) { return in; }
};
242 
243 void test_copies(){
244     using namespace oneapi::tbb::flow;
245 
246     CopyCounterBody<int, int> b;
247 
248     graph g;
249     function_node<int, int> fn(g, unlimited, b);
250 
251     CopyCounterBody<int, int> b2 = copy_body<CopyCounterBody<int, int>, function_node<int, int>>(fn);
252 
253     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
254 }
255 
256 void test_rejecting(){
257     oneapi::tbb::flow::graph g;
258     oneapi::tbb::flow::function_node <int, int, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial,
259                                                                     [&](int v){
260                                                                         size_t ms = 50;
261                                                                         std::chrono::milliseconds sleep_time( ms );
262                                                                         std::this_thread::sleep_for( sleep_time );
263                                                                         return v;
264                                                                     });
265 
266     test_push_receiver<int> sink(g);
267 
268     make_edge(fnode, sink);
269 
270     for(int i = 0; i < 10; ++i){
271         fnode.try_put(i);
272     }
273 
274     g.wait_for_all();
275     CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed");
276 }
277 
278 //! Test function_node with rejecting policy
279 //! \brief \ref interface
280 TEST_CASE("function_node with rejecting policy"){
281     test_rejecting();
282 }
283 
284 //! Test body copying and copy_body logic
285 //! \brief \ref interface
286 TEST_CASE("function_node and body copying"){
287     test_copies();
288 }
289 
290 //! Test constructors
291 //! \brief \ref interface
292 TEST_CASE("function_node constructors"){
293     test_policy_ctors();
294 }
295 
296 //! Test inheritance relations
297 //! \brief \ref interface
298 TEST_CASE("function_node superclasses"){
299     test_inheritance<int, int>();
300     test_inheritance<void*, float>();
301 }
302 
303 //! Test function_node buffering
304 //! \brief \ref requirement
305 TEST_CASE("function_node buffering"){
306     test_buffering<oneapi::tbb::flow::rejecting>();
307     test_buffering<oneapi::tbb::flow::queueing>();
308 }
309 
310 //! Test function_node broadcasting
311 //! \brief \ref requirement
312 TEST_CASE("function_node broadcast"){
313     test_broadcast();
314 }
315 
316 //! Test deduction guides
317 //! \brief \ref interface \ref requirement
318 TEST_CASE("Deduction guides"){
319 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
320     test_deduction_guides();
321 #endif
322 }
323 
324 //! Test priorities work in single-threaded configuration
325 //! \brief \ref requirement
326 TEST_CASE("function_node priority support"){
327     test_priority();
328 }
329 
330 //! Test that measured concurrency respects set limits
331 //! \brief \ref requirement
332 TEST_CASE("concurrency follows set limits"){
333     test_node_concurrency();
334 }
335 
336 //! Test calling function body
337 //! \brief \ref interface \ref requirement
338 TEST_CASE("Test function_node body") {
339     test_func_body();
340 }
341