1 /*
2     Copyright (c) 2020 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #include "common/test.h"
18 
19 #include "common/utils.h"
20 #include "common/graph_utils.h"
21 
22 #include "oneapi/tbb/flow_graph.h"
23 #include "oneapi/tbb/task_arena.h"
24 #include "oneapi/tbb/global_control.h"
25 
26 #include "conformance_flowgraph.h"
27 
28 //! \file conformance_multifunction_node.cpp
//! \brief Test for [flow_graph.multifunction_node] specification
30 
31 /*
32 TODO: implement missing conformance tests for multifunction_node:
33   - [ ] Implement test_forwarding that checks messages are broadcast to all the successors connected
34     to the output port the message is being sent to. And check that the value passed is the
35     actual one received.
36   - [ ] Explicit test for copy constructor of the node.
37   - [ ] Constructor with explicitly passed Policy parameter: `template<typename Body>
38     multifunction_node( graph &g, size_t concurrency, Body body, Policy(), node_priority_t priority = no_priority )'.
39   - [ ] Concurrency testing of the node: make a loop over possible concurrency levels. It is
40     important to test at least on five values: 1, oneapi::tbb::flow::serial, `max_allowed_parallelism'
41     obtained from `oneapi::tbb::global_control', `oneapi::tbb::flow::unlimited', and, if `max allowed
42     parallelism' is > 2, use something in the middle of the [1, max_allowed_parallelism]
43     interval. Use `utils::ExactConcurrencyLevel' entity (extending it if necessary).
44   - [ ] make `test_rejecting' deterministic, i.e. avoid dependency on OS scheduling of the threads;
45     add check that `try_put()' returns `false'
46   - [ ] The `copy_body' function copies altered body (e.g. after successful `try_put()' call).
47   - [ ] `output_ports_type' is defined and accessible by the user.
48   - [ ] Explicit test on `mfn::output_ports()' method.
49   - [ ] The copy constructor and copy assignment are called for the node's input and output types.
50   - [ ] Add CTAD test.
51 */
52 
53 template< typename OutputType >
54 struct mf_functor {
55 
56     std::atomic<std::size_t>& local_execute_count;
57 
58     mf_functor(std::atomic<std::size_t>& execute_count ) :
59         local_execute_count (execute_count)
60     {  }
61 
62     mf_functor( const mf_functor &f ) : local_execute_count(f.local_execute_count) { }
63     void operator=(const mf_functor &f) { local_execute_count = std::size_t(f.local_execute_count); }
64 
65     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
66        ++local_execute_count;
67        std::get<0>(op).try_put(argument);
68     }
69 
70 };
71 
72 template<typename I, typename O>
73 void test_inheritance(){
74     using namespace oneapi::tbb::flow;
75 
76     CHECK_MESSAGE( (std::is_base_of<graph_node, multifunction_node<I, O>>::value), "multifunction_node should be derived from graph_node");
77     CHECK_MESSAGE( (std::is_base_of<receiver<I>, multifunction_node<I, O>>::value), "multifunction_node should be derived from receiver<Input>");
78 }
79 
80 void test_multifunc_body(){
81     oneapi::tbb::flow::graph g;
82     std::atomic<size_t> local_count(0);
83     mf_functor<std::tuple<int>> fun(local_count);
84 
85     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>, oneapi::tbb::flow::rejecting> node1(g, oneapi::tbb::flow::unlimited, fun);
86 
87     const size_t n = 10;
88     for(size_t i = 0; i < n; ++i) {
89         CHECK_MESSAGE((node1.try_put(1) == true), "try_put needs to return true");
90     }
91     g.wait_for_all();
92 
93     CHECK_MESSAGE( (local_count == n), "Body of the node needs to be executed N times");
94 }
95 
96 template<typename I, typename O>
97 struct CopyCounterBody{
98     size_t copy_count;
99 
100     CopyCounterBody():
101         copy_count(0) {}
102 
103     CopyCounterBody(const CopyCounterBody<I, O>& other):
104         copy_count(other.copy_count + 1) {}
105 
106     CopyCounterBody& operator=(const CopyCounterBody<I, O>& other)
107     { copy_count = other.copy_count + 1; return *this;}
108 
109     void operator()( const I& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
110        std::get<0>(op).try_put(argument);
111     }
112 };
113 
114 void test_copies(){
115      using namespace oneapi::tbb::flow;
116 
117      CopyCounterBody<int, std::tuple<int>> b;
118 
119      graph g;
120      multifunction_node<int, std::tuple<int>> fn(g, unlimited, b);
121 
122      CopyCounterBody<int, std::tuple<int>> b2 = copy_body<CopyCounterBody<int, std::tuple<int>>,
123                                                           multifunction_node<int, std::tuple<int>>>(fn);
124 
125      CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies");
126 }
127 
128 template< typename OutputType >
129 struct id_functor {
130     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
131        std::get<0>(op).try_put(argument);
132     }
133 };
134 
135 void test_forwarding(){
136     oneapi::tbb::flow::graph g;
137     id_functor<int> fun;
138 
139     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> node1(g, oneapi::tbb::flow::unlimited, fun);
140     test_push_receiver<int> node2(g);
141     test_push_receiver<int> node3(g);
142 
143     oneapi::tbb::flow::make_edge(node1, node2);
144     oneapi::tbb::flow::make_edge(node1, node3);
145 
146     node1.try_put(1);
147     g.wait_for_all();
148 
149     CHECK_MESSAGE( (get_count(node3) == 1), "Descendant of the node must receive one message.");
150     CHECK_MESSAGE( (get_count(node2) == 1), "Descendant of the node must receive one message.");
151 }
152 
153 void test_rejecting_buffering(){
154     oneapi::tbb::flow::graph g;
155     id_functor<int> fun;
156 
157     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>, oneapi::tbb::flow::rejecting> node(g, oneapi::tbb::flow::unlimited, fun);
158     oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
159 
160     oneapi::tbb::flow::make_edge(node, rejecter);
161     node.try_put(1);
162 
163     int tmp = -1;
164     CHECK_MESSAGE( (std::get<0>(node.output_ports()).try_get(tmp) == false), "try_get after rejection should not succeed");
165     CHECK_MESSAGE( (tmp == -1), "try_get after rejection should alter passed value");
166     g.wait_for_all();
167 }
168 
169 void test_policy_ctors(){
170     using namespace oneapi::tbb::flow;
171     graph g;
172 
173     id_functor<int> fun;
174 
175     multifunction_node<int, std::tuple<int>, lightweight> lw_node(g, oneapi::tbb::flow::serial, fun);
176     multifunction_node<int, std::tuple<int>, queueing_lightweight> qlw_node(g, oneapi::tbb::flow::serial, fun);
177     multifunction_node<int, std::tuple<int>, rejecting_lightweight> rlw_node(g, oneapi::tbb::flow::serial, fun);
178 
179 }
180 
// Number of concurrency_functor bodies currently executing (incremented on entry,
// decremented before the body forwards its message). Reset by test_node_concurrency().
std::atomic<size_t> my_concurrency;
// Largest concurrency level recorded by concurrency_functor. Reset by test_node_concurrency().
std::atomic<size_t> my_max_concurrency;
183 
184 struct concurrency_functor {
185     void operator()( const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ) {
186         ++my_concurrency;
187 
188         size_t old_value = my_max_concurrency;
189         while(my_max_concurrency < my_concurrency &&
190               !my_max_concurrency.compare_exchange_weak(old_value, my_concurrency))
191             ;
192 
193         size_t ms = 1000;
194         std::chrono::milliseconds sleep_time( ms );
195         std::this_thread::sleep_for( sleep_time );
196 
197         --my_concurrency;
198         std::get<0>(op).try_put(argument);
199     }
200 
201 };
202 
203 void test_node_concurrency(){
204     my_concurrency = 0;
205     my_max_concurrency = 0;
206 
207     oneapi::tbb::flow::graph g;
208 
209     concurrency_functor counter;
210     oneapi::tbb::flow::multifunction_node <int, std::tuple<int>> fnode(g, oneapi::tbb::flow::serial, counter);
211 
212     test_push_receiver<int> sink(g);
213 
214     make_edge(std::get<0>(fnode.output_ports()), sink);
215 
216     for(int i = 0; i < 10; ++i){
217         fnode.try_put(i);
218     }
219 
220     g.wait_for_all();
221     CHECK_MESSAGE( ( my_max_concurrency.load() == 1), "Measured parallelism over limit");
222 }
223 
224 
225 void test_priority(){
226     size_t concurrency_limit = 1;
227     oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, concurrency_limit);
228 
229     oneapi::tbb::flow::graph g;
230 
231     oneapi::tbb::flow::continue_node<int> source(g,
232                                          [](oneapi::tbb::flow::continue_msg){ return 1;});
233     source.try_put(oneapi::tbb::flow::continue_msg());
234 
235     first_functor<int>::first_id = -1;
236     first_functor<int> low_functor(1);
237     first_functor<int> high_functor(2);
238 
239     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> high(g, oneapi::tbb::flow::unlimited, high_functor, oneapi::tbb::flow::node_priority_t(1));
240     oneapi::tbb::flow::multifunction_node<int, std::tuple<int>> low(g, oneapi::tbb::flow::unlimited, low_functor);
241 
242     make_edge(source, low);
243     make_edge(source, high);
244 
245     g.wait_for_all();
246 
247     CHECK_MESSAGE( (first_functor<int>::first_id == 2), "High priority node should execute first");
248 }
249 
250 void test_rejecting(){
251     oneapi::tbb::flow::graph g;
252     oneapi::tbb::flow::multifunction_node <int, std::tuple<int>, oneapi::tbb::flow::rejecting> fnode(g, oneapi::tbb::flow::serial,
253                                                                     [&](const int& argument, oneapi::tbb::flow::multifunction_node<int, std::tuple<int>>::output_ports_type &op ){
254                                                                         size_t ms = 50;
255                                                                         std::chrono::milliseconds sleep_time( ms );
256                                                                         std::this_thread::sleep_for( sleep_time );
257                                                                         std::get<0>(op).try_put(argument);
258                                                                     });
259 
260     test_push_receiver<int> sink(g);
261 
262     make_edge(std::get<0>(fnode.output_ports()), sink);
263 
264     for(int i = 0; i < 10; ++i){
265         fnode.try_put(i);
266     }
267 
268     g.wait_for_all();
269     CHECK_MESSAGE( (get_count(sink) == 1), "Messages should be rejected while the first is being processed");
270 }
271 
//! Test multifunction_node with rejecting policy: of several messages put while the serial
//! body is busy, only one is expected to reach the successor (see test_rejecting)
//! \brief \ref interface
TEST_CASE("multifunction_node with rejecting policy"){
    test_rejecting();
}
277 
//! Test priorities: the node constructed with a node_priority_t must run before the node
//! constructed without one (see test_priority)
//! \brief \ref interface
TEST_CASE("multifunction_node priority"){
    test_priority();
}
283 
//! Test concurrency: the measured parallelism of a node created with `serial` concurrency
//! must be exactly one (see test_node_concurrency)
//! \brief \ref interface
TEST_CASE("multifunction_node concurrency"){
    test_node_concurrency();
}
289 
//! Test constructors: the node is constructed with the lightweight, queueing_lightweight
//! and rejecting_lightweight policies (see test_policy_ctors)
//! \brief \ref interface
TEST_CASE("multifunction_node constructors"){
    test_policy_ctors();
}
295 
//! Test multifunction_node buffering: try_get() on the output port must fail when the
//! successor rejects the message (see test_rejecting_buffering)
//! \brief \ref requirement
TEST_CASE("multifunction_node buffering"){
    test_rejecting_buffering();
}
301 
//! Test multifunction_node broadcasting: every successor connected to the output port must
//! receive the message (see test_forwarding)
//! \brief \ref requirement
TEST_CASE("multifunction_node broadcast"){
    test_forwarding();
}
307 
308 //! Test body copying and copy_body logic
309 //! \brief \ref interface
310 TEST_CASE("multifunction_node constructors"){
311     test_copies();
312 }
313 
//! Test calling function body: the body must be executed once per message put to the node
//! (see test_multifunc_body)
//! \brief \ref interface \ref requirement
TEST_CASE("multifunction_node body") {
    test_multifunc_body();
}
319 
//! Test inheritance relations: the node must derive from graph_node and receiver<Input>;
//! checked for two different input/output type combinations
//! \brief \ref interface
TEST_CASE("multifunction_node superclasses"){
    test_inheritance<int, std::tuple<int>>();
    test_inheritance<void*, std::tuple<float>>();
}
326