xref: /oneTBB/test/tbb/test_write_once_node.cpp (revision b15aabb3)
1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 
18 #include "common/config.h"
19 
// TODO revamp: move the parts that depend on __TBB_EXTRA_DEBUG into separate test(s). Enabling
// it in all of the tests means we might end up testing a build of the product that is different
// from what is actually released.
23 #define __TBB_EXTRA_DEBUG 1
24 #include "tbb/flow_graph.h"
25 
26 #include "common/test.h"
27 #include "common/utils.h"
28 #include "common/utils_assert.h"
29 #include "common/graph_utils.h"
30 #include "common/test_follows_and_precedes_api.h"
31 
// Number of write attempts per round in simple_read_write_tests (each followed by N reads),
// and the default number of native threads in parallel_read_write_tests. N+1 is used as the
// pre-filled value in simple_read_write_tests, outside the 1..N range written by its loop.
#define N 300
// Number of rounds per node; simple_read_write_tests takes a different path in odd rounds.
#define T 4
// Number of counting receivers attached to the node under test in each round.
#define M 4
35 
36 
37 //! \file test_write_once_node.cpp
38 //! \brief Test for [flow_graph.write_once_node] specification
39 
40 
41 template< typename R >
42 void simple_read_write_tests() {
43     tbb::flow::graph g;
44     tbb::flow::write_once_node<R> n(g);
45 
46     for ( int t = 0; t < T; ++t ) {
47         R v0(0);
48         std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
49         for (size_t i = 0; i < M; ++i) {
50             r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
51         }
52 
53 
54         CHECK_MESSAGE( n.is_valid() == false, "" );
55         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
56 
57         if ( t % 2 ) {
58             CHECK_MESSAGE( n.try_put( static_cast<R>(N+1) ), "" );
59             CHECK_MESSAGE( n.is_valid() == true, "" );
60             CHECK_MESSAGE( n.try_get( v0 ) == true, "" );
61             CHECK_MESSAGE( v0 == R(N+1), "" );
62         }
63 
64         for (int i = 0; i < M; ++i) {
65             tbb::flow::make_edge( n, *r[i] );
66         }
67 
68         if ( t%2 ) {
69             for (int i = 0; i < M; ++i) {
70                 size_t c = r[i]->my_count;
71                 CHECK_MESSAGE( int(c) == 1, "" );
72             }
73         }
74 
75         for (int i = 1; i <= N; ++i ) {
76             R v1(static_cast<R>(i));
77 
78             bool result = n.try_put( v1 );
79             if ( !(t%2) && i == 1 )
80                 CHECK_MESSAGE( result == true, "" );
81             else
82                 CHECK_MESSAGE( result == false, "" );
83 
84             CHECK_MESSAGE( n.is_valid() == true, "" );
85 
86             for (int j = 0; j < N; ++j ) {
87                 R v2(0);
88                 CHECK_MESSAGE( n.try_get( v2 ), "" );
89                 if ( t%2 )
90                     CHECK_MESSAGE( R(N+1) == v2, "" );
91                 else
92                     CHECK_MESSAGE( R(1) == v2, "" );
93             }
94         }
95         for (int i = 0; i < M; ++i) {
96             size_t c = r[i]->my_count;
97             CHECK_MESSAGE( int(c) == 1, "" );
98         }
99         for (int i = 0; i < M; ++i) {
100             tbb::flow::remove_edge( n, *r[i] );
101         }
102         CHECK_MESSAGE( n.try_put( R(0) ) == false, "" );
103         for (int i = 0; i < M; ++i) {
104             size_t c = r[i]->my_count;
105             CHECK_MESSAGE( int(c) == 1, "" );
106         }
107         n.clear();
108         CHECK_MESSAGE( n.is_valid() == false, "" );
109         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
110     }
111 }
112 
113 template< typename R >
114 class native_body : utils::NoAssign {
115     tbb::flow::write_once_node<R> &my_node;
116 
117 public:
118 
119     native_body( tbb::flow::write_once_node<R> &n ) : my_node(n) {}
120 
121     void operator()( int i ) const {
122         R v1(static_cast<R>(i));
123         CHECK_MESSAGE( my_node.try_put( v1 ) == false, "" );
124         CHECK_MESSAGE( my_node.is_valid() == true, "" );
125         CHECK_MESSAGE( my_node.try_get( v1 ) == true, "" );
126         CHECK_MESSAGE( v1 == R(-1), "" );
127     }
128 };
129 
130 template< typename R >
131 void parallel_read_write_tests() {
132     tbb::flow::graph g;
133     tbb::flow::write_once_node<R> n(g);
134     //Create a vector of identical nodes
135     std::vector< tbb::flow::write_once_node<R> > wo_vec(2, n);
136 
137     for (size_t node_idx=0; node_idx<wo_vec.size(); ++node_idx) {
138         for ( int t = 0; t < T; ++t ) {
139             std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
140             for (size_t i = 0; i < M; ++i) {
141                 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
142             }
143 
144 
145             for (int i = 0; i < M; ++i) {
146                 tbb::flow::make_edge( wo_vec[node_idx], *r[i] );
147             }
148             R v0;
149             CHECK_MESSAGE( wo_vec[node_idx].is_valid() == false, "" );
150             CHECK_MESSAGE( wo_vec[node_idx].try_get( v0 ) == false, "" );
151 
152             CHECK_MESSAGE( wo_vec[node_idx].try_put( R(-1) ), "" );
153 #if TBB_TEST_LOW_WORKLOAD
154             const int nthreads = 30;
155 #else
156             const int nthreads = N;
157 #endif
158             utils::NativeParallelFor( nthreads, native_body<R>( wo_vec[node_idx] ) );
159 
160             for (int i = 0; i < M; ++i) {
161                 size_t c = r[i]->my_count;
162                 CHECK_MESSAGE( int(c) == 1, "" );
163             }
164             for (int i = 0; i < M; ++i) {
165                 tbb::flow::remove_edge( wo_vec[node_idx], *r[i] );
166             }
167             CHECK_MESSAGE( wo_vec[node_idx].try_put( R(0) ) == false, "" );
168             for (int i = 0; i < M; ++i) {
169                 size_t c = r[i]->my_count;
170                 CHECK_MESSAGE( int(c) == 1, "" );
171             }
172             wo_vec[node_idx].clear();
173             CHECK_MESSAGE( wo_vec[node_idx].is_valid() == false, "" );
174             CHECK_MESSAGE( wo_vec[node_idx].try_get( v0 ) == false, "" );
175         }
176     }
177 }
178 
179 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
180 #include <array>
181 #include <vector>
182 void test_follows_and_precedes_api() {
183     using msg_t = tbb::flow::continue_msg;
184 
185     std::array<msg_t, 3> messages_for_follows= {msg_t(), msg_t(), msg_t()};
186     std::vector<msg_t> messages_for_precedes = {msg_t()};
187 
188     follows_and_precedes_testing::test_follows<msg_t, tbb::flow::write_once_node<msg_t>>(messages_for_follows);
189     follows_and_precedes_testing::test_precedes<msg_t, tbb::flow::write_once_node<msg_t>>(messages_for_precedes);
190 }
191 #endif // __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
192 
193 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
// Compile-time checks (static_assert) that class template argument deduction yields
// write_once_node<int> when the node is constructed from follows()/precedes() of a
// broadcast_node<int> (preview node-set API only) or from an existing write_once_node<int>.
// The exact declaration syntax matters here: it selects which deduction candidate is used.
void test_deduction_guides() {
    using namespace tbb::flow;

    graph g;
    broadcast_node<int> b1(g);
    write_once_node<int> wo0(g);

#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
    // Message type deduced from a node set built with follows().
    write_once_node wo1(follows(b1));
    static_assert(std::is_same_v<decltype(wo1), write_once_node<int>>);

    // Message type deduced from a node set built with precedes().
    write_once_node wo2(precedes(b1));
    static_assert(std::is_same_v<decltype(wo2), write_once_node<int>>);
#endif

    // Constructing from an existing node must deduce that node's type.
    write_once_node wo3(wo0);
    static_assert(std::is_same_v<decltype(wo3), write_once_node<int>>);
}
212 #endif
213 
214 //! Test read-write properties
215 //! \brief \ref requirement \ref error_guessing
216 TEST_CASE("Read-write tests"){
217     simple_read_write_tests<int>();
218     simple_read_write_tests<float>();
219 }
220 
221 //! Test read-write properties under parallelism
222 //! \brief \ref requirement \ref error_guessing \ref stress
223 TEST_CASE("Parallel read-write tests"){
224     for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
225         tbb::task_arena arena(p);
226         arena.execute(
227             [&]() {
228                 parallel_read_write_tests<int>();
229                 parallel_read_write_tests<float>();
230                 test_reserving_nodes<tbb::flow::write_once_node, size_t>();
231             }
232         );
233 	}
234 }
235 
236 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
237 //! Test deprecated follows and precedes API
238 //! \brief \ref error_guessing
239 TEST_CASE("Test follows and precedes API"){
240     test_follows_and_precedes_api();
241 }
242 #endif
243 
244 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
245 //! Test deduction guides
246 //! \brief \ref requirement
247 TEST_CASE("Deduction guides"){
248     test_deduction_guides();
249 }
250 #endif
251