xref: /oneTBB/test/tbb/test_write_once_node.cpp (revision fa3268c3)
1 /*
2     Copyright (c) 2005-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 
#include "common/config.h"

#include "tbb/flow_graph.h"

#include "common/test.h"
#include "common/utils.h"
#include "common/utils_assert.h"
#include "common/graph_utils.h"
#include "common/test_follows_and_precedes_api.h"

#include <memory>
#include <type_traits>
#include <vector>
27 
// Number of values simple_read_write_tests tries to put per trial (and the number of
// try_get probes after each put); also the native thread count used by
// parallel_read_write_tests unless TBB_TEST_LOW_WORKLOAD is set.
#define N 300
// Number of trials run against each node under test.
#define T 4
// Number of harness_counting_receiver successors attached to the node in each trial.
#define M 4
31 
32 
33 //! \file test_write_once_node.cpp
34 //! \brief Test for [flow_graph.write_once_node] specification
35 
36 
37 template< typename R >
38 void simple_read_write_tests() {
39     tbb::flow::graph g;
40     tbb::flow::write_once_node<R> n(g);
41 
42     for ( int t = 0; t < T; ++t ) {
43         R v0(0);
44         std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
45         for (size_t i = 0; i < M; ++i) {
46             r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
47         }
48 
49 
50         CHECK_MESSAGE( n.is_valid() == false, "" );
51         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
52 
53         if ( t % 2 ) {
54             CHECK_MESSAGE( n.try_put( static_cast<R>(N+1) ), "" );
55             CHECK_MESSAGE( n.is_valid() == true, "" );
56             CHECK_MESSAGE( n.try_get( v0 ) == true, "" );
57             CHECK_MESSAGE( v0 == R(N+1), "" );
58         }
59 
60         for (int i = 0; i < M; ++i) {
61             tbb::flow::make_edge( n, *r[i] );
62         }
63 
64         if ( t%2 ) {
65             for (int i = 0; i < M; ++i) {
66                 size_t c = r[i]->my_count;
67                 CHECK_MESSAGE( int(c) == 1, "" );
68             }
69         }
70 
71         for (int i = 1; i <= N; ++i ) {
72             R v1(static_cast<R>(i));
73 
74             bool result = n.try_put( v1 );
75             if ( !(t%2) && i == 1 )
76                 CHECK_MESSAGE( result == true, "" );
77             else
78                 CHECK_MESSAGE( result == false, "" );
79 
80             CHECK_MESSAGE( n.is_valid() == true, "" );
81 
82             for (int j = 0; j < N; ++j ) {
83                 R v2(0);
84                 CHECK_MESSAGE( n.try_get( v2 ), "" );
85                 if ( t%2 )
86                     CHECK_MESSAGE( R(N+1) == v2, "" );
87                 else
88                     CHECK_MESSAGE( R(1) == v2, "" );
89             }
90         }
91         for (int i = 0; i < M; ++i) {
92             size_t c = r[i]->my_count;
93             CHECK_MESSAGE( int(c) == 1, "" );
94         }
95         for (int i = 0; i < M; ++i) {
96             tbb::flow::remove_edge( n, *r[i] );
97         }
98         CHECK_MESSAGE( n.try_put( R(0) ) == false, "" );
99         for (int i = 0; i < M; ++i) {
100             size_t c = r[i]->my_count;
101             CHECK_MESSAGE( int(c) == 1, "" );
102         }
103         n.clear();
104         CHECK_MESSAGE( n.is_valid() == false, "" );
105         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
106     }
107 }
108 
109 template< typename R >
110 class native_body : utils::NoAssign {
111     tbb::flow::write_once_node<R> &my_node;
112 
113 public:
114 
115     native_body( tbb::flow::write_once_node<R> &n ) : my_node(n) {}
116 
117     void operator()( int i ) const {
118         R v1(static_cast<R>(i));
119         CHECK_MESSAGE( my_node.try_put( v1 ) == false, "" );
120         CHECK_MESSAGE( my_node.is_valid() == true, "" );
121         CHECK_MESSAGE( my_node.try_get( v1 ) == true, "" );
122         CHECK_MESSAGE( v1 == R(-1), "" );
123     }
124 };
125 
126 template< typename R >
127 void parallel_read_write_tests() {
128     tbb::flow::graph g;
129     tbb::flow::write_once_node<R> n(g);
130     //Create a vector of identical nodes
131     std::vector< tbb::flow::write_once_node<R> > wo_vec(2, n);
132 
133     for (size_t node_idx=0; node_idx<wo_vec.size(); ++node_idx) {
134         for ( int t = 0; t < T; ++t ) {
135             std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
136             for (size_t i = 0; i < M; ++i) {
137                 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
138             }
139 
140 
141             for (int i = 0; i < M; ++i) {
142                 tbb::flow::make_edge( wo_vec[node_idx], *r[i] );
143             }
144             R v0;
145             CHECK_MESSAGE( wo_vec[node_idx].is_valid() == false, "" );
146             CHECK_MESSAGE( wo_vec[node_idx].try_get( v0 ) == false, "" );
147 
148             CHECK_MESSAGE( wo_vec[node_idx].try_put( R(-1) ), "" );
149 #if TBB_TEST_LOW_WORKLOAD
150             const int nthreads = 30;
151 #else
152             const int nthreads = N;
153 #endif
154             utils::NativeParallelFor( nthreads, native_body<R>( wo_vec[node_idx] ) );
155 
156             for (int i = 0; i < M; ++i) {
157                 size_t c = r[i]->my_count;
158                 CHECK_MESSAGE( int(c) == 1, "" );
159             }
160             for (int i = 0; i < M; ++i) {
161                 tbb::flow::remove_edge( wo_vec[node_idx], *r[i] );
162             }
163             CHECK_MESSAGE( wo_vec[node_idx].try_put( R(0) ) == false, "" );
164             for (int i = 0; i < M; ++i) {
165                 size_t c = r[i]->my_count;
166                 CHECK_MESSAGE( int(c) == 1, "" );
167             }
168             wo_vec[node_idx].clear();
169             CHECK_MESSAGE( wo_vec[node_idx].is_valid() == false, "" );
170             CHECK_MESSAGE( wo_vec[node_idx].try_get( v0 ) == false, "" );
171         }
172     }
173 }
174 
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
#include <array>
#include <vector>
//! Exercises the preview follows()/precedes() construction helpers for
//! write_once_node, using continue_msg as the message type.
void test_follows_and_precedes_api() {
    using msg_t = tbb::flow::continue_msg;
    using node_t = tbb::flow::write_once_node<msg_t>;

    // Three messages feed the follows() check; a single one feeds precedes().
    std::array<msg_t, 3> follows_messages = { msg_t{}, msg_t{}, msg_t{} };
    std::vector<msg_t> precedes_messages = { msg_t{} };

    follows_and_precedes_testing::test_follows<msg_t, node_t>( follows_messages );
    follows_and_precedes_testing::test_precedes<msg_t, node_t>( precedes_messages );
}
#endif // __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
188 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Compile-time checks that class template argument deduction yields
//! write_once_node<int> from follows()/precedes() (preview API only) and
//! from copy construction.
void test_deduction_guides() {
    using namespace tbb::flow;
    using expected_node_t = write_once_node<int>;

    graph g;
    broadcast_node<int> bcast(g);
    expected_node_t origin(g);

#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
    write_once_node from_follows(follows(bcast));
    static_assert(std::is_same_v<decltype(from_follows), expected_node_t>);

    write_once_node from_precedes(precedes(bcast));
    static_assert(std::is_same_v<decltype(from_precedes), expected_node_t>);
#endif

    write_once_node from_copy(origin);
    static_assert(std::is_same_v<decltype(from_copy), expected_node_t>);
}
#endif
209 
210 //! Test read-write properties
211 //! \brief \ref requirement \ref error_guessing
212 TEST_CASE("Read-write tests"){
213     simple_read_write_tests<int>();
214     simple_read_write_tests<float>();
215 }
216 
217 //! Test read-write properties under parallelism
218 //! \brief \ref requirement \ref error_guessing \ref stress
219 TEST_CASE("Parallel read-write tests"){
220     for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
221         tbb::task_arena arena(p);
222         arena.execute(
223             [&]() {
224                 parallel_read_write_tests<int>();
225                 parallel_read_write_tests<float>();
226                 test_reserving_nodes<tbb::flow::write_once_node, size_t>();
227             }
228         );
229 	}
230 }
231 
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
//! Test deprecated follows and precedes API
//! \brief \ref error_guessing
TEST_CASE("Test follows and precedes API"){
    // Delegates to the preview-only helper defined under the same macro.
    test_follows_and_precedes_api();
}
#endif // __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
239 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Test deduction guides
//! \brief \ref requirement
TEST_CASE("Deduction guides"){
    // All checks are static_asserts; the call just keeps the helper referenced.
    test_deduction_guides();
}
#endif // __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
247