xref: /oneTBB/test/tbb/test_overwrite_node.cpp (revision d86ed7fb)
1 /*
2     Copyright (c) 2005-2020 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #include "common/config.h"
18 
// TODO revamp: move parts dependent on __TBB_EXTRA_DEBUG into separate test(s), since having these
// parts in all of the tests might mean testing a product that is different from what is actually
// released.
22 #define __TBB_EXTRA_DEBUG 1
23 #include "tbb/flow_graph.h"
24 
25 #include "common/test.h"
26 #include "common/utils.h"
27 #include "common/utils_assert.h"
28 #include "common/graph_utils.h"
29 #include "common/test_follows_and_precedes_api.h"
30 
31 
32 //! \file test_overwrite_node.cpp
33 //! \brief Test for [flow_graph.overwrite_node] specification
34 
35 
// Test dimensions shared by the read/write tests in this file.
#define N 300   // values put per round in the serial test; body invocations in the parallel test (full workload)
#define T 4     // rounds run against each node
#define M 5     // counting receivers attached to the node in each round
39 
40 template< typename R >
41 void simple_read_write_tests() {
42     tbb::flow::graph g;
43     tbb::flow::overwrite_node<R> n(g);
44 
45     for ( int t = 0; t < T; ++t ) {
46         R v0(N+1);
47         std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
48         for (size_t i = 0; i < M; ++i) {
49             r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
50         }
51 
52         CHECK_MESSAGE( n.is_valid() == false, "" );
53         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
54         if ( t % 2 ) {
55             CHECK_MESSAGE( n.try_put( static_cast<R>(N) ), "" );
56             CHECK_MESSAGE( n.is_valid() == true, "" );
57             CHECK_MESSAGE( n.try_get( v0 ) == true, "" );
58             CHECK_MESSAGE( v0 == R(N), "" );
59         }
60 
61         for (int i = 0; i < M; ++i) {
62             tbb::flow::make_edge( n, *r[i] );
63         }
64 
65         for (int i = 0; i < N; ++i ) {
66             R v1(static_cast<R>(i));
67             CHECK_MESSAGE( n.try_put( v1 ), "" );
68             CHECK_MESSAGE( n.is_valid() == true, "" );
69             for (int j = 0; j < N; ++j ) {
70                 R v2(0);
71                 CHECK_MESSAGE( n.try_get( v2 ), "" );
72                 CHECK_MESSAGE( v1 == v2, "" );
73             }
74         }
75         for (int i = 0; i < M; ++i) {
76             size_t c = r[i]->my_count;
77             CHECK_MESSAGE( int(c) == N+t%2, "" );
78         }
79         for (int i = 0; i < M; ++i) {
80             tbb::flow::remove_edge( n, *r[i] );
81         }
82         CHECK_MESSAGE( n.try_put( R(0) ), "" );
83         for (int i = 0; i < M; ++i) {
84             size_t c = r[i]->my_count;
85             CHECK_MESSAGE( int(c) == N+t%2, "" );
86         }
87         n.clear();
88         CHECK_MESSAGE( n.is_valid() == false, "" );
89         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
90     }
91 }
92 
93 template< typename R >
94 class native_body : utils::NoAssign {
95     tbb::flow::overwrite_node<R> &my_node;
96 
97 public:
98 
99     native_body( tbb::flow::overwrite_node<R> &n ) : my_node(n) {}
100 
101     void operator()( int i ) const {
102         R v1(static_cast<R>(i));
103         CHECK_MESSAGE( my_node.try_put( v1 ), "" );
104         CHECK_MESSAGE( my_node.is_valid() == true, "" );
105     }
106 };
107 
108 template< typename R >
109 void parallel_read_write_tests() {
110     tbb::flow::graph g;
111     tbb::flow::overwrite_node<R> n(g);
112     //Create a vector of identical nodes
113     std::vector< tbb::flow::overwrite_node<R> > ow_vec(2, n);
114 
115     for (size_t node_idx=0; node_idx<ow_vec.size(); ++node_idx) {
116         for ( int t = 0; t < T; ++t ) {
117             std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
118             for (size_t i = 0; i < M; ++i) {
119                 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
120             }
121 
122             for (int i = 0; i < M; ++i) {
123                 tbb::flow::make_edge( ow_vec[node_idx], *r[i] );
124             }
125             R v0;
126             CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" );
127             CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" );
128 
129 #if TBB_TEST_LOW_WORKLOAD
130             const int nthreads = 30;
131 #else
132             const int nthreads = N;
133 #endif
134             utils::NativeParallelFor( nthreads, native_body<R>( ow_vec[node_idx] ) );
135 
136             for (int i = 0; i < M; ++i) {
137                 size_t c = r[i]->my_count;
138                 CHECK_MESSAGE( int(c) == nthreads, "" );
139             }
140             for (int i = 0; i < M; ++i) {
141                 tbb::flow::remove_edge( ow_vec[node_idx], *r[i] );
142             }
143             CHECK_MESSAGE( ow_vec[node_idx].try_put( R(0) ), "" );
144             for (int i = 0; i < M; ++i) {
145                 size_t c = r[i]->my_count;
146                 CHECK_MESSAGE( int(c) == nthreads, "" );
147             }
148             ow_vec[node_idx].clear();
149             CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" );
150             CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" );
151         }
152     }
153 }
154 
155 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
156 #include <array>
157 #include <vector>
158 void test_follows_and_precedes_api() {
159     using msg_t = tbb::flow::continue_msg;
160 
161     std::array<msg_t, 3> messages_for_follows = { {msg_t(), msg_t(), msg_t()} };
162     std::vector<msg_t> messages_for_precedes = {msg_t()};
163 
164     follows_and_precedes_testing::test_follows<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_follows);
165     follows_and_precedes_testing::test_precedes<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_precedes);
166 }
167 #endif
168 
169 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
170 void test_deduction_guides() {
171     using namespace tbb::flow;
172 
173     graph g;
174     broadcast_node<int> b1(g);
175     overwrite_node<int> o0(g);
176 
177 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
178     overwrite_node o1(follows(b1));
179     static_assert(std::is_same_v<decltype(o1), overwrite_node<int>>);
180 
181     overwrite_node o2(precedes(b1));
182     static_assert(std::is_same_v<decltype(o2), overwrite_node<int>>);
183 #endif
184 
185     overwrite_node o3(o0);
186     static_assert(std::is_same_v<decltype(o3), overwrite_node<int>>);
187 }
188 #endif
189 
190 //! Test read-write properties
191 //! \brief \ref requirement \ref error_guessing
192 TEST_CASE("Read-write"){
193     simple_read_write_tests<int>();
194     simple_read_write_tests<float>();
195 }
196 
197 //! Read-write and ParallelFor tests under limited parallelism
198 //! \brief \ref error_guessing
199 TEST_CASE("Limited parallelism"){
200     for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
201         tbb::task_arena arena(p);
202         arena.execute(
203             [&]() {
204                 parallel_read_write_tests<int>();
205                 parallel_read_write_tests<float>();
206                 test_reserving_nodes<tbb::flow::overwrite_node, size_t>();
207             }
208         );
209 	}
210 }
211 
212 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
213 //! Test follows and precedes API
214 //! \brief \ref error_guessing
215 TEST_CASE("Follows and precedes API"){
216     test_follows_and_precedes_api();
217 }
218 #endif
219 
220 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Test deduction guides
222 //! \brief \ref requirement
223 TEST_CASE("Deduction guides"){
224     test_deduction_guides();
225 }
226 #endif
227 
228 //! Test try_release
229 //! \brief \ref error_guessing
230 TEST_CASE("try_release"){
231     tbb::flow::graph g;
232 
233     tbb::flow::overwrite_node<int> on(g);
234 
235     CHECK_MESSAGE ((on.try_release()== true), "try_release should return true");
236 }
237