1 /*
2 Copyright (c) 2020-2023 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 #ifndef __TBB_test_conformance_conformance_flowgraph_H
18 #define __TBB_test_conformance_conformance_flowgraph_H
19
20 #include "common/test.h"
21 #include "common/utils.h"
22 #include "common/graph_utils.h"
23 #include "common/concurrency_tracker.h"
24
25 #include "oneapi/tbb/flow_graph.h"
26 #include "oneapi/tbb/task_arena.h"
27 #include "oneapi/tbb/global_control.h"
28
29 namespace conformance {
30
// Value that test_forwarding() pushes into the node under test and then expects
// every successor to have received.
constexpr int expected = 5;

// Successor used to collect what a node under test emits; its contents are
// drained with get_values().
template<typename V>
using test_push_receiver = oneapi::tbb::flow::queue_node<V>;

// Output-ports type of a single-output multifunction_node<Input, std::tuple<Output>>;
// taken by the two-argument multifunction_node bodies defined in this file.
template<typename Input, typename Output = Input>
using multifunc_ports_t =
typename oneapi::tbb::flow::multifunction_node<Input, std::tuple<Output>>::output_ports_type;

// Gateway type handed to async_node<Input, Output> bodies.
template<typename Input, typename Output = Input>
using async_ports_t =
typename oneapi::tbb::flow::async_node<Input, Output>::gateway_type;
43
// Message type whose default-constructibility, copy-constructibility and
// copy-assignability are selected through template parameters, so conformance
// tests can exercise nodes with message types that have or lack these features.
//
// NOTE(review): a constructor template is never a copy constructor and an
// operator= template is never a copy-assignment operator, so the implicitly
// declared copy constructor / copy assignment still exist even when
// CopyConstructible or CopyAssignable is false. Only DefaultConstructible
// changes the type's traits, because the user-declared message(int) suppresses
// the implicit default constructor.
template<bool DefaultConstructible, bool CopyConstructible, bool CopyAssignable>
struct message {
    int data;

    // Implicit conversion from int; the helpers build messages from ints, e.g. V(0).
    message(int _data) : data(_data) {}

    // Available only when DefaultConstructible. The payload is zero-initialized so a
    // default-constructed message never carries an indeterminate value.
    template<bool T = DefaultConstructible, typename = typename std::enable_if<T>::type>
    message() : data(0) {}

    template<bool T = CopyConstructible, typename = typename std::enable_if<T>::type>
    message(const message& msg) : data(msg.data) {}

    template<bool T = CopyAssignable, typename = typename std::enable_if<T>::type>
    message& operator=(const message& msg) {
        this->data = msg.data;
        return *this;
    }

    // Compares the payload against a plain int.
    bool operator==(const int expected_data) const {
        return data == expected_data;
    }

    bool operator==(const message& msg) const {
        return data == msg.data;
    }

    // Lets a message be used where a std::size_t is required (e.g. as a sequencer key).
    operator std::size_t() const {
        return static_cast<std::size_t>(data);
    }

    operator int() const {
        return data;
    }
};
78
79 template<typename V>
get_values(test_push_receiver<V> & rr)80 typename std::enable_if<!std::is_default_constructible<V>::value, std::vector<V>>::type get_values( test_push_receiver<V>& rr ) {
81 std::vector<V> messages;
82 V tmp(0);
83
84 while (rr.try_get(tmp)) {
85 messages.push_back(tmp);
86 }
87 return messages;
88 }
89
90 template<typename V>
get_values(test_push_receiver<V> & rr)91 typename std::enable_if<std::is_default_constructible<V>::value, std::vector<V>>::type get_values( test_push_receiver<V>& rr ) {
92 std::vector<V> messages;
93 V tmp;
94
95 while (rr.try_get(tmp)) {
96 messages.push_back(tmp);
97 }
98 return messages;
99 }
100
101 template<typename Node, typename InputType = void>
102 bool produce_messages(Node& node, int arg = 1) {
103 utils::suppress_unused_warning(arg);
104 #if defined CONFORMANCE_INPUT_NODE
105 node.activate();
106 return true;
107 #elif defined CONFORMANCE_CONTINUE_NODE
108 return node.try_put(InputType());
109 #else
110 return node.try_put(InputType(arg));
111 #endif
112 }
113
// Overload selected only when T and U are the same type; reports true.
template<typename T, typename U>
typename std::enable_if<std::is_same<T, U>::value, bool>::type check_output_type(){
    return std::is_same<T, U>::value;
}
118
// Overload selected only when T and U differ; reports false.
template<typename T, typename U>
typename std::enable_if<!std::is_same<T, U>::value, bool>::type check_output_type(){
    return std::is_same<T, U>::value;
}
123
// Sequencer body: maps a message to its sequence number.
template<typename T>
struct sequencer_functor {
    // Wrapper carrying an explicit sequence number alongside its payload.
    struct seq_message {
        std::size_t id;
        T data;
    };

    using input_type = T;

    // A plain value is converted to std::size_t and used as its own sequence number.
    std::size_t operator()(T v) {
        std::size_t sequence_number = v;
        return sequence_number;
    }

    // A wrapped message reports the sequence number stored in `id`.
    std::size_t operator()(seq_message msg) {
        const std::size_t sequence_number = msg.id;
        return sequence_number;
    }
};
141
142 template<typename OutputType>
143 struct track_first_id_functor {
144 int my_id;
145 static std::atomic<int> first_id;
146
track_first_id_functortrack_first_id_functor147 track_first_id_functor(int id) : my_id(id) {}
148
operatortrack_first_id_functor149 OutputType operator()( OutputType argument ) {
150 int old_value = -1;
151 while(first_id == -1 &&
152 !first_id.compare_exchange_strong(old_value, my_id));
153 return argument;
154 }
155
156 template<typename InputType>
operatortrack_first_id_functor157 OutputType operator()( InputType& ) {
158 return operator()(OutputType(0));
159 }
160
161 template<typename InputType>
operatortrack_first_id_functor162 void operator()( InputType, async_ports_t<InputType, OutputType>& g ) {
163 g.try_put(operator()(OutputType(0)));
164 }
165
166 template<typename InputType>
operatortrack_first_id_functor167 void operator()( InputType, multifunc_ports_t<InputType, OutputType>& op ) {
168 std::get<0>(op).try_put(operator()(OutputType(0)));
169 }
170 };
171
172 template<typename OutputType>
173 std::atomic<int> track_first_id_functor<OutputType>::first_id = {-1};
174
175 template<typename OutputType>
176 struct counting_functor {
177 OutputType return_value;
178
179 static std::atomic<std::size_t> execute_count;
180
return_valuecounting_functor181 counting_functor( OutputType value = OutputType(0) ) : return_value(value) {
182 execute_count = 0;
183 }
184
counting_functorcounting_functor185 counting_functor( const counting_functor & c ) : return_value(static_cast<int>(c.return_value)) {
186 execute_count = 0;
187 }
188
189 template<typename InputType>
operatorcounting_functor190 OutputType operator()( InputType ) {
191 ++execute_count;
192 return return_value;
193 }
194
195 template<typename InputType>
operatorcounting_functor196 void operator()( InputType, multifunc_ports_t<InputType, OutputType>& op ) {
197 ++execute_count;
198 std::get<0>(op).try_put(return_value);
199 }
200
operatorcounting_functor201 OutputType operator()( oneapi::tbb::flow_control& fc ) {
202 ++execute_count;
203 if(execute_count > std::size_t(return_value)) {
204 fc.stop();
205 return return_value;
206 }
207 return return_value;
208 }
209
210 template<typename InputType>
operatorcounting_functor211 void operator()( InputType, async_ports_t<InputType, OutputType>& g ) {
212 ++execute_count;
213 g.try_put(return_value);
214 }
215 };
216
217 template<typename OutputType>
218 std::atomic<std::size_t> counting_functor<OutputType>::execute_count = {0};
219
220 template<typename OutputType>
221 struct dummy_functor {
222 template<typename InputType>
operatordummy_functor223 OutputType operator()( InputType ) {
224 #ifdef CONFORMANCE_CONTINUE_NODE
225 return OutputType();
226 #else
227 return OutputType(0);
228 #endif
229 }
230
231 template<typename InputType>
operatordummy_functor232 void operator()( InputType, multifunc_ports_t<InputType, OutputType>& op ) {
233 std::get<0>(op).try_put(OutputType(0));
234 }
235
236 template<typename InputType>
operatordummy_functor237 void operator()( InputType, async_ports_t<InputType, OutputType>& g ) {
238 g.try_put(OutputType(0));
239 }
240
241 template<typename InputType, typename T>
operatordummy_functor242 void operator()( InputType, std::tuple<T, T>& ) {}
243
operatordummy_functor244 OutputType operator()( oneapi::tbb::flow_control & fc ) {
245 static bool check = false;
246 if(check) {
247 check = false;
248 fc.stop();
249 return OutputType(1);
250 }
251 check = true;
252 return OutputType(1);
253 }
254 };
255
256 struct wait_flag_body {
257 static std::atomic<bool> flag;
258
wait_flag_bodywait_flag_body259 wait_flag_body() {
260 flag.store(false);
261 }
262
263 template<typename InputType>
operatorwait_flag_body264 InputType operator()( InputType ) {
265 while(!flag.load()) { utils::yield(); };
266 #ifdef CONFORMANCE_CONTINUE_NODE
267 return InputType();
268 #else
269 return InputType(0);
270 #endif
271 }
272
273 template<typename InputType>
operatorwait_flag_body274 void operator()( InputType argument, multifunc_ports_t<InputType>& op ) {
275 while(!flag.load()) { };
276 std::get<0>(op).try_put(argument);
277 }
278
279 template<typename InputType>
operatorwait_flag_body280 void operator()( InputType argument, async_ports_t<InputType>& g ) {
281 while(!flag.load()) { };
282 g.try_put(argument);
283 }
284 };
285
286 std::atomic<bool> wait_flag_body::flag{false};
287
288 struct concurrency_peak_checker_body {
289 std::size_t required_max_concurrency = 0;
290
291 concurrency_peak_checker_body( std::size_t req_max_concurrency = 0 ) :
required_max_concurrencyconcurrency_peak_checker_body292 required_max_concurrency(req_max_concurrency) {}
293
294 concurrency_peak_checker_body( const concurrency_peak_checker_body & ) = default;
295
operatorconcurrency_peak_checker_body296 int operator()( oneapi::tbb::flow_control & fc ) {
297 static int counter = 0;
298 utils::ConcurrencyTracker ct;
299 if(++counter > 500) {
300 counter = 0;
301 fc.stop();
302 return 1;
303 }
304 utils::doDummyWork(1000);
305 CHECK_MESSAGE((int)utils::ConcurrencyTracker::PeakParallelism() <= required_max_concurrency,
306 "Input node is serial and its body never invoked concurrently");
307 return 1;
308 }
309
operatorconcurrency_peak_checker_body310 int operator()( int ) {
311 utils::ConcurrencyTracker ct;
312 utils::doDummyWork(1000);
313 CHECK_MESSAGE((int)utils::ConcurrencyTracker::PeakParallelism() <= required_max_concurrency,
314 "Measured parallelism is not expected");
315 return 1;
316 }
317
operatorconcurrency_peak_checker_body318 void operator()( const int& argument, multifunc_ports_t<int>& op ) {
319 utils::ConcurrencyTracker ct;
320 utils::doDummyWork(1000);
321 CHECK_MESSAGE((int)utils::ConcurrencyTracker::PeakParallelism() <= required_max_concurrency,
322 "Measured parallelism is not expected");
323 std::get<0>(op).try_put(argument);
324 }
325
operatorconcurrency_peak_checker_body326 void operator()( const int& argument , async_ports_t<int>& g ) {
327 utils::ConcurrencyTracker ct;
328 utils::doDummyWork(1000);
329 CHECK_MESSAGE((int)utils::ConcurrencyTracker::PeakParallelism() <= required_max_concurrency,
330 "Measured parallelism is not expected");
331 g.try_put(argument);
332 }
333 };
334
335 template<typename OutputType, typename InputType = int>
336 struct copy_counting_object {
337 std::size_t copy_count;/*increases on every new copied object*/
338 mutable std::size_t copies_count;/*count number of objects copied from this object*/
339 std::size_t assign_count;
340 bool is_copy;
341
copy_counting_objectcopy_counting_object342 copy_counting_object():
343 copy_count(0), copies_count(0), assign_count(0), is_copy(false) {}
344
copy_counting_objectcopy_counting_object345 copy_counting_object(int):
346 copy_count(0), copies_count(0), assign_count(0), is_copy(false) {}
347
copy_counting_objectcopy_counting_object348 copy_counting_object( const copy_counting_object<OutputType, InputType>& other ):
349 copy_count(other.copy_count + 1), is_copy(true) {
350 ++other.copies_count;
351 }
352
353 copy_counting_object& operator=( const copy_counting_object<OutputType, InputType>& other ) {
354 assign_count = other.assign_count + 1;
355 is_copy = true;
356 return *this;
357 }
358
operatorcopy_counting_object359 OutputType operator()( InputType ) {
360 return OutputType(1);
361 }
362
operatorcopy_counting_object363 void operator()( InputType, multifunc_ports_t<InputType,OutputType>& op ) {
364 std::get<0>(op).try_put(OutputType(1));
365 }
366
operatorcopy_counting_object367 void operator()( InputType , async_ports_t<InputType, OutputType>& g) {
368 g.try_put(OutputType(1));
369 }
370
operatorcopy_counting_object371 OutputType operator()( oneapi::tbb::flow_control & fc ) {
372 static bool check = false;
373 if(check) {
374 check = false;
375 fc.stop();
376 return OutputType(1);
377 }
378 check = true;
379 return OutputType(1);
380 }
381 };
382
383 template <typename OutputType = int>
384 struct passthru_body {
operatorpassthru_body385 OutputType operator()( const oneapi::tbb::flow::continue_msg& ) {
386 return OutputType(0);
387 }
388
operatorpassthru_body389 OutputType operator()( const OutputType& i ) {
390 return i;
391 }
392
operatorpassthru_body393 OutputType operator()( oneapi::tbb::flow_control & fc ) {
394 static bool check = false;
395 if(check) {
396 check = false;
397 fc.stop();
398 return OutputType(0);
399 }
400 check = true;
401 return OutputType(0);
402 }
403
operatorpassthru_body404 void operator()( OutputType argument, multifunc_ports_t<OutputType>& op ) {
405 std::get<0>(op).try_put(argument);
406 }
407
operatorpassthru_body408 void operator()( OutputType argument, async_ports_t<OutputType>& g ) {
409 g.try_put(argument);
410 }
411 };
412
413 template<typename Node, typename InputType, typename OutputType, typename ...Args>
test_body_exec(Args...node_args)414 void test_body_exec(Args... node_args) {
415 oneapi::tbb::flow::graph g;
416 counting_functor<OutputType> counting_body;
417 counting_body.execute_count = 0;
418
419 Node testing_node(g, node_args..., counting_body);
420
421 constexpr std::size_t n = 10;
422 for(std::size_t i = 0; i < n; ++i) {
423 CHECK_MESSAGE((produce_messages<Node, InputType>(testing_node) == true),
424 "try_put of first node should return true");
425 }
426 g.wait_for_all();
427
428 CHECK_MESSAGE((counting_body.execute_count == n), "Body of the first node needs to be executed N times");
429 }
430
431 template<typename Node, typename Body, typename ...Args>
test_copy_body_function(Args...node_args)432 void test_copy_body_function(Args... node_args) {
433 using namespace oneapi::tbb::flow;
434
435 Body base_body;
436
437 graph g;
438
439 Node testing_node(g, node_args..., base_body);
440
441 Body b2 = copy_body<Body, Node>(testing_node);
442
443 CHECK_MESSAGE((base_body.copy_count + 1 < b2.copy_count), "copy_body and constructor should copy bodies");
444 }
445
446 template<typename Node, typename InputType, typename ...Args>
test_buffering(Args...node_args)447 void test_buffering(Args... node_args) {
448 oneapi::tbb::flow::graph g;
449
450 Node testing_node(g, node_args...);
451 oneapi::tbb::flow::limiter_node<int> rejecter(g, 0);
452
453 oneapi::tbb::flow::make_edge(testing_node, rejecter);
454
455 int tmp = -1;
456 produce_messages<Node, InputType>(testing_node);
457 g.wait_for_all();
458
459
460 #if defined CONFORMANCE_BUFFERING_NODES || defined CONFORMANCE_INPUT_NODE
461 CHECK_MESSAGE((testing_node.try_get(tmp) == true), "try_get after rejection should succeed");
462 CHECK_MESSAGE((tmp == 1), "try_get after rejection should set value");
463 #else
464 #ifdef CONFORMANCE_MULTIFUNCTION_NODE
465 CHECK_MESSAGE((std::get<0>(testing_node.output_ports()).try_get(tmp) == false), "try_get after rejection should not succeed");
466 #else
467 CHECK_MESSAGE((testing_node.try_get(tmp) == false), "try_get after rejection should not succeed");
468 #endif
469 CHECK_MESSAGE((tmp == -1), "try_get after rejection should not alter passed value");
470 #endif
471 }
472
473
474 template<typename Node, typename InputType, typename OutputType = InputType, typename ...Args>
test_forwarding(std::size_t messages_received,Args...node_args)475 void test_forwarding(std::size_t messages_received, Args... node_args) {
476 oneapi::tbb::flow::graph g;
477
478 Node testing_node(g, node_args...);
479 std::vector<std::unique_ptr<test_push_receiver<OutputType>>> receiver_nodes;
480
481 for(std::size_t i = 0; i < 10; ++i) {
482 receiver_nodes.emplace_back(new test_push_receiver<OutputType>(g));
483 oneapi::tbb::flow::make_edge(testing_node, *receiver_nodes.back());
484 }
485
486 produce_messages<Node, InputType>(testing_node, expected);
487
488 #ifdef CONFORMANCE_INPUT_NODE
489 CHECK_MESSAGE(expected == messages_received, "For correct execution of test");
490 #endif
491
492 g.wait_for_all();
493 for(auto& receiver : receiver_nodes) {
494 auto values = get_values(*receiver);
495 CHECK_MESSAGE((values.size() == messages_received), std::string("Descendant of the node must receive " + std::to_string(messages_received) + " message."));
496 CHECK_MESSAGE((values[0] == expected), "Value passed is the actual one received.");
497 }
498 }
499
500 template<typename Node, typename ...Args>
test_forwarding_single_push(Args...node_args)501 void test_forwarding_single_push(Args... node_args) {
502 oneapi::tbb::flow::graph g;
503
504 Node testing_node(g, node_args...);
505 test_push_receiver<int> suc_node1(g);
506 test_push_receiver<int> suc_node2(g);
507
508 oneapi::tbb::flow::make_edge(testing_node, suc_node1);
509 oneapi::tbb::flow::make_edge(testing_node, suc_node2);
510
511 testing_node.try_put(0);
512 g.wait_for_all();
513
514 auto values1 = get_values(suc_node1);
515 auto values2 = get_values(suc_node2);
516 CHECK_MESSAGE((values1.size() != values2.size()), "Only one descendant the node needs to receive");
517 CHECK_MESSAGE((values1.size() + values2.size() == 1), "All messages need to be received");
518
519 testing_node.try_put(1);
520 g.wait_for_all();
521
522 auto values3 = get_values(suc_node1);
523 auto values4 = get_values(suc_node2);
524 CHECK_MESSAGE((values3.size() != values4.size()), "Only one descendant the node needs to receive");
525 CHECK_MESSAGE((values3.size() + values4.size() == 1), "All messages need to be received");
526
527 #ifdef CONFORMANCE_QUEUE_NODE
528 CHECK_MESSAGE((values1[0] == 0), "Value passed is the actual one received");
529 CHECK_MESSAGE((values3[0] == 1), "Value passed is the actual one received");
530 #else
531 if(values1.size() == 1) {
532 CHECK_MESSAGE((values1[0] == 0), "Value passed is the actual one received");
533 }else{
534 CHECK_MESSAGE((values2[0] == 0), "Value passed is the actual one received");
535 }
536 #endif
537 }
538
539 template<typename Node, typename InputType, typename OutputType>
test_inheritance()540 void test_inheritance() {
541 using namespace oneapi::tbb::flow;
542
543 CHECK_MESSAGE((std::is_base_of<graph_node, Node>::value), "Node should be derived from graph_node");
544 CHECK_MESSAGE((std::is_base_of<receiver<InputType>, Node>::value), "Node should be derived from receiver<Input>");
545 CHECK_MESSAGE((std::is_base_of<sender<OutputType>, Node>::value), "Node should be derived from sender<Output>");
546 }
547
548 template<typename Node>
test_copy_ctor()549 void test_copy_ctor() {
550 using namespace oneapi::tbb::flow;
551 graph g;
552
553 dummy_functor<int> fun1;
554 conformance::copy_counting_object<int> fun2;
555
556 Node node0(g, unlimited, fun1);
557 Node node1(g, unlimited, fun2);
558 test_push_receiver<int> suc_node1(g);
559 test_push_receiver<int> suc_node2(g);
560
561 oneapi::tbb::flow::make_edge(node0, node1);
562 oneapi::tbb::flow::make_edge(node1, suc_node1);
563
564 Node node_copy(node1);
565
566 conformance::copy_counting_object<int> b2 = copy_body<conformance::copy_counting_object<int>, Node>(node_copy);
567
568 CHECK_MESSAGE((fun2.copy_count + 1 < b2.copy_count), "constructor should copy bodies");
569
570 oneapi::tbb::flow::make_edge(node_copy, suc_node2);
571
572 node_copy.try_put(1);
573 g.wait_for_all();
574
575 CHECK_MESSAGE((get_values(suc_node1).size() == 0 && get_values(suc_node2).size() == 1), "Copied node doesn`t copy successor");
576
577 node0.try_put(1);
578 g.wait_for_all();
579
580 CHECK_MESSAGE((get_values(suc_node1).size() == 1 && get_values(suc_node2).size() == 0), "Copied node doesn`t copy predecessor");
581 }
582
583 template<typename Node, typename ...Args>
test_copy_ctor_for_buffering_nodes(Args...node_args)584 void test_copy_ctor_for_buffering_nodes(Args... node_args) {
585 oneapi::tbb::flow::graph g;
586
587 dummy_functor<int> fun;
588
589 Node testing_node(g, node_args...);
590 oneapi::tbb::flow::continue_node<int> pred_node(g, fun);
591 test_push_receiver<int> suc_node1(g);
592 test_push_receiver<int> suc_node2(g);
593
594 oneapi::tbb::flow::make_edge(pred_node, testing_node);
595 oneapi::tbb::flow::make_edge(testing_node, suc_node1);
596
597 #ifdef CONFORMANCE_OVERWRITE_NODE
598 testing_node.try_put(1);
599 #endif
600
601 Node node_copy(testing_node);
602
603 #ifdef CONFORMANCE_OVERWRITE_NODE
604 int tmp;
605 CHECK_MESSAGE((!node_copy.is_valid() && !node_copy.try_get(tmp)), "The buffered value is not copied from src");
606 get_values(suc_node1);
607 #endif
608
609 oneapi::tbb::flow::make_edge(node_copy, suc_node2);
610
611 node_copy.try_put(0);
612 g.wait_for_all();
613
614 CHECK_MESSAGE((get_values(suc_node1).size() == 0 && get_values(suc_node2).size() == 1), "Copied node doesn`t copy successor");
615
616 #ifdef CONFORMANCE_OVERWRITE_NODE
617 node_copy.clear();
618 testing_node.clear();
619 #endif
620
621 pred_node.try_put(oneapi::tbb::flow::continue_msg());
622 g.wait_for_all();
623
624 CHECK_MESSAGE((get_values(suc_node1).size() == 1 && get_values(suc_node2).size() == 0), "Copied node doesn`t copy predecessor");
625 }
626
627 template<typename Node, typename InputType, typename ...Args>
test_priority(Args...node_args)628 void test_priority(Args... node_args) {
629
630
631 oneapi::tbb::flow::graph g;
632 oneapi::tbb::flow::continue_node<InputType> source(g, dummy_functor<InputType>());
633
634 track_first_id_functor<int>::first_id = -1;
635 track_first_id_functor<int> low_functor(1);
636 track_first_id_functor<int> high_functor(2);
637
638 // Due to args... we cannot create the nodes inside the lambda with old compilers
639 Node high(g, node_args..., high_functor, oneapi::tbb::flow::node_priority_t(1));
640 Node low(g, node_args..., low_functor);
641
642 tbb::task_arena a(1, 1);
643 a.execute([&] {
644 g.reset(); // attach to this arena
645
646 make_edge(source, low);
647 make_edge(source, high);
648 source.try_put(oneapi::tbb::flow::continue_msg());
649
650 g.wait_for_all();
651
652 CHECK_MESSAGE((track_first_id_functor<int>::first_id == 2), "High priority node should execute first");
653 });
654 }
655
656 template<typename Node>
test_concurrency()657 void test_concurrency() {
658 auto max_num_threads = oneapi::tbb::this_task_arena::max_concurrency();
659
660 oneapi::tbb::global_control c(oneapi::tbb::global_control::max_allowed_parallelism,
661 max_num_threads);
662
663 std::vector<int> threads_count = {1, oneapi::tbb::flow::serial, max_num_threads, oneapi::tbb::flow::unlimited};
664
665 if(max_num_threads > 2) {
666 threads_count.push_back(max_num_threads / 2);
667 }
668
669 for(auto num_threads : threads_count) {
670 utils::ConcurrencyTracker::Reset();
671 int expected_threads = num_threads;
672 if(num_threads == oneapi::tbb::flow::unlimited) {
673 expected_threads = max_num_threads;
674 }
675 if(num_threads == oneapi::tbb::flow::serial) {
676 expected_threads = 1;
677 }
678 oneapi::tbb::flow::graph g;
679 concurrency_peak_checker_body counter(expected_threads);
680 Node fnode(g, num_threads, counter);
681
682 test_push_receiver<int> suc_node(g);
683
684 make_edge(fnode, suc_node);
685
686 for(int i = 0; i < 500; ++i) {
687 fnode.try_put(i);
688 }
689 g.wait_for_all();
690 }
691 }
692
693 template<typename Node>
test_rejecting()694 void test_rejecting() {
695 oneapi::tbb::flow::graph g;
696
697 wait_flag_body body;
698 Node fnode(g, oneapi::tbb::flow::serial, body);
699
700 test_push_receiver<int> suc_node(g);
701
702 make_edge(fnode, suc_node);
703
704 fnode.try_put(0);
705
706 CHECK_MESSAGE((!fnode.try_put(1)), "Messages should be rejected while the first is being processed");
707
708 wait_flag_body::flag = true;
709
710 g.wait_for_all();
711 CHECK_MESSAGE((get_values(suc_node).size() == 1), "Messages should be rejected while the first is being processed");
712 }
713
714 template<typename Node, typename CountingBody>
test_output_input_class()715 void test_output_input_class() {
716 using namespace oneapi::tbb::flow;
717
718 passthru_body<CountingBody> fun;
719
720 graph g;
721 Node node1(g, unlimited, fun);
722 test_push_receiver<CountingBody> suc_node(g);
723 make_edge(node1, suc_node);
724 CountingBody b1;
725 CountingBody b2;
726 node1.try_put(b1);
727 g.wait_for_all();
728 suc_node.try_get(b2);
729 DOCTEST_WARN_MESSAGE((b1.copies_count > 0), "The type Input must meet the DefaultConstructible and CopyConstructible requirements");
730 DOCTEST_WARN_MESSAGE((b2.is_copy), "The type Output must meet the CopyConstructible requirements");
731 }
732
733 template<typename Node, typename Output = copy_counting_object<int>>
test_output_class()734 void test_output_class() {
735 using namespace oneapi::tbb::flow;
736
737 passthru_body<Output> fun;
738
739 graph g;
740 Node node1(g, fun);
741 test_push_receiver<Output> suc_node(g);
742 make_edge(node1, suc_node);
743
744 #ifdef CONFORMANCE_INPUT_NODE
745 node1.activate();
746 #else
747 node1.try_put(oneapi::tbb::flow::continue_msg());
748 #endif
749
750 g.wait_for_all();
751 Output b;
752 suc_node.try_get(b);
753 DOCTEST_WARN_MESSAGE((b.is_copy), "The type Output must meet the CopyConstructible requirements");
754 }
755
756 template<typename Node>
test_with_reserving_join_node_class()757 void test_with_reserving_join_node_class() {
758 using namespace oneapi::tbb::flow;
759
760 graph g;
761
762 function_node<int, int> static_result_computer_n(
763 g, serial,
764 [&](const int& msg) {
765 // compute the result using incoming message and pass it further, e.g.:
766 int result = int((msg >> 2) / 4);
767 return result;
768 });
769 Node testing_node(g); // for buffering once computed value
770
771 buffer_node<int> buffer_n(g);
772 join_node<std::tuple<int, int>, reserving> join_n(g);
773
774 std::atomic<int> number{2};
775 std::atomic<int> counter{0};
776 function_node<std::tuple<int, int>> consumer_n(
777 g, unlimited,
778 [&](const std::tuple<int, int>& arg) {
779 // use the precomputed static result along with dynamic data
780 ++counter;
781 #ifdef CONFORMANCE_OVERWRITE_NODE
782 CHECK_MESSAGE((std::get<0>(arg) == int((number >> 2) / 4)), "A overwrite_node store a single item that can be overwritten");
783 #else
784 CHECK_MESSAGE((std::get<0>(arg) == int((number >> 2) / 4)), "A write_once_node store a single item that cannot be overwritten");
785 #endif
786 });
787
788 make_edge(static_result_computer_n, testing_node);
789 make_edge(testing_node, input_port<0>(join_n));
790 make_edge(buffer_n, input_port<1>(join_n));
791 make_edge(join_n, consumer_n);
792
793 // do one-time calculation that will be reused many times further in the graph
794 static_result_computer_n.try_put(number);
795
796 constexpr int put_count = 50;
797 for (int i = 0; i < put_count / 2; i++) {
798 buffer_n.try_put(i);
799 }
800 #ifdef CONFORMANCE_OVERWRITE_NODE
801 number = 3;
802 #endif
803 static_result_computer_n.try_put(number);
804 for (int i = 0; i < put_count / 2; i++) {
805 buffer_n.try_put(i);
806 }
807
808 g.wait_for_all();
809 CHECK_MESSAGE((counter == put_count), "join_node with reserving policy \
810 if at least one successor accepts the tuple must consume messages");
811 }
812 }
813 #endif // __TBB_test_conformance_conformance_flowgraph_H
814