1*51c0b2f7Stbbdev /*
2*51c0b2f7Stbbdev     Copyright (c) 2020 Intel Corporation
3*51c0b2f7Stbbdev 
4*51c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
5*51c0b2f7Stbbdev     you may not use this file except in compliance with the License.
6*51c0b2f7Stbbdev     You may obtain a copy of the License at
7*51c0b2f7Stbbdev 
8*51c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
9*51c0b2f7Stbbdev 
10*51c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
11*51c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
12*51c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*51c0b2f7Stbbdev     See the License for the specific language governing permissions and
14*51c0b2f7Stbbdev     limitations under the License.
15*51c0b2f7Stbbdev */
16*51c0b2f7Stbbdev 
17*51c0b2f7Stbbdev #include "common/test.h"
18*51c0b2f7Stbbdev 
19*51c0b2f7Stbbdev #include "common/utils.h"
20*51c0b2f7Stbbdev #include "common/graph_utils.h"
21*51c0b2f7Stbbdev 
22*51c0b2f7Stbbdev #include "tbb/flow_graph.h"
23*51c0b2f7Stbbdev #include "tbb/task_arena.h"
24*51c0b2f7Stbbdev #include "tbb/global_control.h"
25*51c0b2f7Stbbdev 
26*51c0b2f7Stbbdev #include "conformance_flowgraph.h"
27*51c0b2f7Stbbdev 
28*51c0b2f7Stbbdev //! \file conformance_function_node.cpp
29*51c0b2f7Stbbdev //! \brief Test for [flow_graph.function_node] specification
30*51c0b2f7Stbbdev 
31*51c0b2f7Stbbdev /*
32*51c0b2f7Stbbdev TODO: implement missing conformance tests for function_node:
  - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body> function_node(
    graph &g, size_t concurrency, Body body, Policy(), node_priority_t priority = no_priority )'
35*51c0b2f7Stbbdev   - [ ] Explicit test for copy constructor of the node.
36*51c0b2f7Stbbdev   - [ ] Rename test_broadcast to test_forwarding and check that the value passed is the actual one
37*51c0b2f7Stbbdev     received.
  - [ ] Concurrency testing of the node: make a loop over possible concurrency levels. It is
    important to test on at least five values: 1, tbb::flow::serial, `max_allowed_parallelism'
    obtained from `tbb::global_control', `tbb::flow::unlimited', and, if `max_allowed_parallelism'
    is > 2, something in the middle of the [1, max_allowed_parallelism] interval. Use the
    `utils::ExactConcurrencyLevel' entity (extending it if necessary).
  - [ ] Make `test_rejecting' deterministic, i.e. avoid depending on how the OS schedules the
    threads; add a check that `try_put()' returns `false' for rejected messages.
  - [ ] Check that the copy constructor and copy assignment are called for the node's input and
    output types.
  - [ ] Check that the `copy_body' function copies the altered body (e.g. after a successful
    `try_put()' call).
47*51c0b2f7Stbbdev   - [ ] Extend CTAD test to check all node's constructors.
48*51c0b2f7Stbbdev */
49*51c0b2f7Stbbdev 
//! Number of concurrency_functor invocations currently in progress.
std::atomic<size_t> my_concurrency;
//! Highest value of my_concurrency observed since it was last reset.
std::atomic<size_t> my_max_concurrency;

//! Body that measures how many of its invocations overlap in time.
//! Each call increments my_concurrency, raises my_max_concurrency to the level
//! it observed (never lowering it), sleeps for `sleep_time` so that concurrent
//! calls get a chance to overlap, then decrements my_concurrency.
//! \tparam OutputType result type; the int argument is returned converted to it.
template< typename OutputType >
struct concurrency_functor {
    //! Time each invocation stays "in flight"; defaults to one second.
    std::chrono::milliseconds sleep_time{1000};

    OutputType operator()( int argument ) {
        // Use the level produced by *this* increment. Re-reading the shared
        // counter could observe a value other threads have already lowered and
        // store it, making the recorded maximum go down.
        const size_t current = ++my_concurrency;

        // Monotonic max: on failure compare_exchange_weak refreshes `observed`
        // with the stored maximum, so the loop ends once it is >= current.
        size_t observed = my_max_concurrency.load();
        while (observed < current &&
               !my_max_concurrency.compare_exchange_weak(observed, current)) {
            // retry with the refreshed `observed`
        }

        std::this_thread::sleep_for( sleep_time );

        --my_concurrency;
        return argument;
    }
};
72*51c0b2f7Stbbdev 
73*51c0b2f7Stbbdev void test_func_body(){
74*51c0b2f7Stbbdev     tbb::flow::graph g;
75*51c0b2f7Stbbdev     inc_functor<int> fun;
76*51c0b2f7Stbbdev     fun.execute_count = 0;
77*51c0b2f7Stbbdev 
78*51c0b2f7Stbbdev     tbb::flow::function_node<int, int> node1(g, tbb::flow::unlimited, fun);
79*51c0b2f7Stbbdev 
80*51c0b2f7Stbbdev     const size_t n = 10;
81*51c0b2f7Stbbdev     for(size_t i = 0; i < n; ++i) {
82*51c0b2f7Stbbdev         CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true");
83*51c0b2f7Stbbdev     }
84*51c0b2f7Stbbdev     g.wait_for_all();
85*51c0b2f7Stbbdev 
86*51c0b2f7Stbbdev     CHECK_MESSAGE( (fun.execute_count == n), "Body of the node needs to be executed N times");
87*51c0b2f7Stbbdev }
88*51c0b2f7Stbbdev 
89*51c0b2f7Stbbdev void test_priority(){
90*51c0b2f7Stbbdev     size_t concurrency_limit = 1;
91*51c0b2f7Stbbdev     tbb::global_control control(tbb::global_control::max_allowed_parallelism, concurrency_limit);
92*51c0b2f7Stbbdev 
93*51c0b2f7Stbbdev     tbb::flow::graph g;
94*51c0b2f7Stbbdev 
95*51c0b2f7Stbbdev     first_functor<int>::first_id.store(-1);
96*51c0b2f7Stbbdev     first_functor<int> low_functor(1);
97*51c0b2f7Stbbdev     first_functor<int> high_functor(2);
98*51c0b2f7Stbbdev 
99*51c0b2f7Stbbdev     tbb::flow::continue_node<int> source(g, [&](tbb::flow::continue_msg){return 1;} );
100*51c0b2f7Stbbdev 
101*51c0b2f7Stbbdev     tbb::flow::function_node<int, int> high(g, tbb::flow::unlimited, high_functor, tbb::flow::node_priority_t(1));
102*51c0b2f7Stbbdev     tbb::flow::function_node<int, int> low(g, tbb::flow::unlimited, low_functor);
103*51c0b2f7Stbbdev 
104*51c0b2f7Stbbdev     make_edge(source, low);
105*51c0b2f7Stbbdev     make_edge(source, high);
106*51c0b2f7Stbbdev 
107*51c0b2f7Stbbdev     source.try_put(tbb::flow::continue_msg());
108*51c0b2f7Stbbdev     g.wait_for_all();
109*51c0b2f7Stbbdev 
110*51c0b2f7Stbbdev     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
111*51c0b2f7Stbbdev }
112*51c0b2f7Stbbdev 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Checks that class template argument deduction produces
//! function_node<int, int> from a body taking const int& and returning int.
void test_deduction_guides(){
    tbb::flow::graph g;

    auto make_one = [](const int&) -> int { return 1; };
    tbb::flow::function_node deduced(g, tbb::flow::unlimited, make_one);

    using expected_type = tbb::flow::function_node<int, int>;
    CHECK_MESSAGE((std::is_same_v<decltype(deduced), expected_type>), "Function node type must be deducible from its body");
}
#endif
123*51c0b2f7Stbbdev 
124*51c0b2f7Stbbdev void test_broadcast(){
125*51c0b2f7Stbbdev     tbb::flow::graph g;
126*51c0b2f7Stbbdev     passthru_body fun;
127*51c0b2f7Stbbdev 
128*51c0b2f7Stbbdev     tbb::flow::function_node<int, int> node1(g, tbb::flow::unlimited, fun);
129*51c0b2f7Stbbdev     test_push_receiver<int> node2(g);
130*51c0b2f7Stbbdev     test_push_receiver<int> node3(g);
131*51c0b2f7Stbbdev 
132*51c0b2f7Stbbdev     tbb::flow::make_edge(node1, node2);
133*51c0b2f7Stbbdev     tbb::flow::make_edge(node1, node3);
134*51c0b2f7Stbbdev 
135*51c0b2f7Stbbdev     node1.try_put(1);
136*51c0b2f7Stbbdev     g.wait_for_all();
137*51c0b2f7Stbbdev 
138*51c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
139*51c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
140*51c0b2f7Stbbdev }
141*51c0b2f7Stbbdev 
142*51c0b2f7Stbbdev template<typename Policy>
143*51c0b2f7Stbbdev void test_buffering(){
144*51c0b2f7Stbbdev     tbb::flow::graph g;
145*51c0b2f7Stbbdev     passthru_body fun;
146*51c0b2f7Stbbdev 
147*51c0b2f7Stbbdev     tbb::flow::function_node<int, int, Policy> node(g, tbb::flow::unlimited, fun);
148*51c0b2f7Stbbdev     tbb::flow::limiter_node<int> rejecter(g, 0);
149*51c0b2f7Stbbdev 
150*51c0b2f7Stbbdev     tbb::flow::make_edge(node, rejecter);
151*51c0b2f7Stbbdev     node.try_put(1);
152*51c0b2f7Stbbdev 
153*51c0b2f7Stbbdev     int tmp = -1;
154*51c0b2f7Stbbdev     CHECK_MESSAGE( (node.try_get(tmp) == false), "try_get after rejection should not succeed");
155*51c0b2f7Stbbdev     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should not alter passed value");
156*51c0b2f7Stbbdev     g.wait_for_all();
157*51c0b2f7Stbbdev }
158*51c0b2f7Stbbdev 
159*51c0b2f7Stbbdev void test_node_concurrency(){
160*51c0b2f7Stbbdev     my_concurrency = 0;
161*51c0b2f7Stbbdev     my_max_concurrency = 0;
162*51c0b2f7Stbbdev 
163*51c0b2f7Stbbdev     tbb::flow::graph g;
164*51c0b2f7Stbbdev     concurrency_functor<int> counter;
165*51c0b2f7Stbbdev     tbb::flow::function_node <int, int> fnode(g, tbb::flow::serial, counter);
166*51c0b2f7Stbbdev 
167*51c0b2f7Stbbdev     test_push_receiver<int> sink(g);
168*51c0b2f7Stbbdev 
169*51c0b2f7Stbbdev     make_edge(fnode, sink);
170*51c0b2f7Stbbdev 
171*51c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
172*51c0b2f7Stbbdev         fnode.try_put(i);
173*51c0b2f7Stbbdev     }
174*51c0b2f7Stbbdev 
175*51c0b2f7Stbbdev     g.wait_for_all();
176*51c0b2f7Stbbdev 
177*51c0b2f7Stbbdev     CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism is not expected");
178*51c0b2f7Stbbdev }
179*51c0b2f7Stbbdev 
180*51c0b2f7Stbbdev template<typename I, typename O>
181*51c0b2f7Stbbdev void test_inheritance(){
182*51c0b2f7Stbbdev     using namespace tbb::flow;
183*51c0b2f7Stbbdev 
184*51c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<graph_node, function_node<I, O>>::value), "function_node should be derived from graph_node");
185*51c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<receiver<I>, function_node<I, O>>::value), "function_node should be derived from receiver<Input>");
186*51c0b2f7Stbbdev     CHECK_MESSAGE( (std::is_base_of<sender<O>, function_node<I, O>>::value), "function_node should be derived from sender<Output>");
187*51c0b2f7Stbbdev }
188*51c0b2f7Stbbdev 
189*51c0b2f7Stbbdev void test_policy_ctors(){
190*51c0b2f7Stbbdev     using namespace tbb::flow;
191*51c0b2f7Stbbdev     graph g;
192*51c0b2f7Stbbdev 
193*51c0b2f7Stbbdev     function_node<int, int, lightweight> lw_node(g, tbb::flow::serial,
194*51c0b2f7Stbbdev                                                           [](int v) { return v;});
195*51c0b2f7Stbbdev     function_node<int, int, queueing_lightweight> qlw_node(g, tbb::flow::serial,
196*51c0b2f7Stbbdev                                                           [](int v) { return v;});
197*51c0b2f7Stbbdev     function_node<int, int, rejecting_lightweight> rlw_node(g, tbb::flow::serial,
198*51c0b2f7Stbbdev                                                           [](int v) { return v;});
199*51c0b2f7Stbbdev 
200*51c0b2f7Stbbdev }
201*51c0b2f7Stbbdev 
//! Identity body that remembers whether it was invoked: `stored` is -1 when
//! constructed and becomes 1 on the first call.
class stateful_functor{
public:
    int stored = -1;

    int operator()(int value){
        stored = 1;
        return value;
    }
};
208*51c0b2f7Stbbdev 
209*51c0b2f7Stbbdev void test_ctors(){
210*51c0b2f7Stbbdev     using namespace tbb::flow;
211*51c0b2f7Stbbdev     graph g;
212*51c0b2f7Stbbdev 
213*51c0b2f7Stbbdev     function_node<int, int> fn(g, unlimited, stateful_functor());
214*51c0b2f7Stbbdev     fn.try_put(0);
215*51c0b2f7Stbbdev     g.wait_for_all();
216*51c0b2f7Stbbdev 
217*51c0b2f7Stbbdev     stateful_functor b1 = copy_body<stateful_functor, function_node<int, int>>(fn);
218*51c0b2f7Stbbdev     CHECK_MESSAGE( (b1.stored == 1), "First node should update");
219*51c0b2f7Stbbdev 
220*51c0b2f7Stbbdev     function_node<int, int> fn2(fn);
221*51c0b2f7Stbbdev     stateful_functor b2 = copy_body<stateful_functor, function_node<int, int>>(fn2);
222*51c0b2f7Stbbdev     CHECK_MESSAGE( (b2.stored == -1), "Copied node should not update");
223*51c0b2f7Stbbdev }
224*51c0b2f7Stbbdev 
//! Identity body (I -> O) that tracks how many copies separate it from the
//! default-constructed original. Both copy construction and copy assignment
//! set copy_count to the source's count plus one.
template<typename I, typename O>
struct CopyCounterBody{
    size_t copy_count;

    CopyCounterBody() : copy_count{0} {}

    CopyCounterBody(const CopyCounterBody& source) : copy_count{source.copy_count + 1} {}

    CopyCounterBody& operator=(const CopyCounterBody& source) {
        copy_count = source.copy_count + 1;
        return *this;
    }

    O operator()(I input) {
        return input;
    }
};
242*51c0b2f7Stbbdev 
243*51c0b2f7Stbbdev void test_copies(){
244*51c0b2f7Stbbdev     using namespace tbb::flow;
245*51c0b2f7Stbbdev 
246*51c0b2f7Stbbdev     CopyCounterBody<int, int> b;
247*51c0b2f7Stbbdev 
248*51c0b2f7Stbbdev     graph g;
249*51c0b2f7Stbbdev     function_node<int, int> fn(g, unlimited, b);
250*51c0b2f7Stbbdev 
251*51c0b2f7Stbbdev     CopyCounterBody<int, int> b2 = copy_body<CopyCounterBody<int, int>, function_node<int, int>>(fn);
252*51c0b2f7Stbbdev 
253*51c0b2f7Stbbdev     CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
254*51c0b2f7Stbbdev }
255*51c0b2f7Stbbdev 
256*51c0b2f7Stbbdev void test_rejecting(){
257*51c0b2f7Stbbdev     tbb::flow::graph g;
258*51c0b2f7Stbbdev     tbb::flow::function_node <int, int, tbb::flow::rejecting> fnode(g, tbb::flow::serial,
259*51c0b2f7Stbbdev                                                                     [&](int v){
260*51c0b2f7Stbbdev                                                                         size_t ms = 50;
261*51c0b2f7Stbbdev                                                                         std::chrono::milliseconds sleep_time( ms );
262*51c0b2f7Stbbdev                                                                         std::this_thread::sleep_for( sleep_time );
263*51c0b2f7Stbbdev                                                                         return v;
264*51c0b2f7Stbbdev                                                                     });
265*51c0b2f7Stbbdev 
266*51c0b2f7Stbbdev     test_push_receiver<int> sink(g);
267*51c0b2f7Stbbdev 
268*51c0b2f7Stbbdev     make_edge(fnode, sink);
269*51c0b2f7Stbbdev 
270*51c0b2f7Stbbdev     for(int i = 0; i < 10; ++i){
271*51c0b2f7Stbbdev         fnode.try_put(i);
272*51c0b2f7Stbbdev     }
273*51c0b2f7Stbbdev 
274*51c0b2f7Stbbdev     g.wait_for_all();
275*51c0b2f7Stbbdev     CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed");
276*51c0b2f7Stbbdev }
277*51c0b2f7Stbbdev 
278*51c0b2f7Stbbdev //! Test function_node with rejecting policy
279*51c0b2f7Stbbdev //! \brief \ref interface
280*51c0b2f7Stbbdev TEST_CASE("function_node with rejecting policy"){
281*51c0b2f7Stbbdev     test_rejecting();
282*51c0b2f7Stbbdev }
283*51c0b2f7Stbbdev 
284*51c0b2f7Stbbdev //! Test body copying and copy_body logic
285*51c0b2f7Stbbdev //! \brief \ref interface
286*51c0b2f7Stbbdev TEST_CASE("function_node and body copying"){
287*51c0b2f7Stbbdev     test_copies();
288*51c0b2f7Stbbdev }
289*51c0b2f7Stbbdev 
290*51c0b2f7Stbbdev //! Test constructors
291*51c0b2f7Stbbdev //! \brief \ref interface
292*51c0b2f7Stbbdev TEST_CASE("function_node constructors"){
293*51c0b2f7Stbbdev     test_policy_ctors();
294*51c0b2f7Stbbdev }
295*51c0b2f7Stbbdev 
296*51c0b2f7Stbbdev //! Test inheritance relations
297*51c0b2f7Stbbdev //! \brief \ref interface
298*51c0b2f7Stbbdev TEST_CASE("function_node superclasses"){
299*51c0b2f7Stbbdev     test_inheritance<int, int>();
300*51c0b2f7Stbbdev     test_inheritance<void*, float>();
301*51c0b2f7Stbbdev }
302*51c0b2f7Stbbdev 
303*51c0b2f7Stbbdev //! Test function_node buffering
304*51c0b2f7Stbbdev //! \brief \ref requirement
305*51c0b2f7Stbbdev TEST_CASE("function_node buffering"){
306*51c0b2f7Stbbdev     test_buffering<tbb::flow::rejecting>();
307*51c0b2f7Stbbdev     test_buffering<tbb::flow::queueing>();
308*51c0b2f7Stbbdev }
309*51c0b2f7Stbbdev 
310*51c0b2f7Stbbdev //! Test function_node broadcasting
311*51c0b2f7Stbbdev //! \brief \ref requirement
312*51c0b2f7Stbbdev TEST_CASE("function_node broadcast"){
313*51c0b2f7Stbbdev     test_broadcast();
314*51c0b2f7Stbbdev }
315*51c0b2f7Stbbdev 
316*51c0b2f7Stbbdev //! Test deduction guides
317*51c0b2f7Stbbdev //! \brief \ref interface \ref requirement
318*51c0b2f7Stbbdev TEST_CASE("Deduction guides"){
319*51c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
320*51c0b2f7Stbbdev     test_deduction_guides();
321*51c0b2f7Stbbdev #endif
322*51c0b2f7Stbbdev }
323*51c0b2f7Stbbdev 
324*51c0b2f7Stbbdev //! Test priorities work in single-threaded configuration
325*51c0b2f7Stbbdev //! \brief \ref requirement
326*51c0b2f7Stbbdev TEST_CASE("function_node priority support"){
327*51c0b2f7Stbbdev     test_priority();
328*51c0b2f7Stbbdev }
329*51c0b2f7Stbbdev 
330*51c0b2f7Stbbdev //! Test that measured concurrency respects set limits
331*51c0b2f7Stbbdev //! \brief \ref requirement
332*51c0b2f7Stbbdev TEST_CASE("concurrency follows set limits"){
333*51c0b2f7Stbbdev     test_node_concurrency();
334*51c0b2f7Stbbdev }
335*51c0b2f7Stbbdev 
336*51c0b2f7Stbbdev //! Test calling function body
337*51c0b2f7Stbbdev //! \brief \ref interface \ref requirement
338*51c0b2f7Stbbdev TEST_CASE("Test function_node body") {
339*51c0b2f7Stbbdev     test_func_body();
340*51c0b2f7Stbbdev }
341