xref: /oneTBB/test/tbb/test_overwrite_node.cpp (revision ce0d258e)
1 /*
2     Copyright (c) 2005-2022 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #include "common/config.h"
18 
19 #include "tbb/flow_graph.h"
20 
21 #include "common/test.h"
22 #include "common/utils.h"
23 #include "common/utils_assert.h"
24 #include "common/graph_utils.h"
25 #include "common/test_follows_and_precedes_api.h"
26 
27 
28 //! \file test_overwrite_node.cpp
29 //! \brief Test for [flow_graph.overwrite_node] specification
30 
31 
// Workload size: number of values put per round in the serial test, and the number of
// native threads in the parallel test unless TBB_TEST_LOW_WORKLOAD is set.
#define N 300
// Number of rounds each read-write test repeats on the same node.
#define T 4
// Number of counting receivers attached to the node under test in each round.
#define M 5
35 
36 template< typename R >
37 void simple_read_write_tests() {
38     tbb::flow::graph g;
39     tbb::flow::overwrite_node<R> n(g);
40 
41     for ( int t = 0; t < T; ++t ) {
42         R v0(N+1);
43         std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
44         for (size_t i = 0; i < M; ++i) {
45             r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
46         }
47 
48         CHECK_MESSAGE( n.is_valid() == false, "" );
49         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
50         if ( t % 2 ) {
51             CHECK_MESSAGE( n.try_put( static_cast<R>(N) ), "" );
52             CHECK_MESSAGE( n.is_valid() == true, "" );
53             CHECK_MESSAGE( n.try_get( v0 ) == true, "" );
54             CHECK_MESSAGE( v0 == R(N), "" );
55         }
56 
57         for (int i = 0; i < M; ++i) {
58             tbb::flow::make_edge( n, *r[i] );
59         }
60 
61         for (int i = 0; i < N; ++i ) {
62             R v1(static_cast<R>(i));
63             CHECK_MESSAGE( n.try_put( v1 ), "" );
64             CHECK_MESSAGE( n.is_valid() == true, "" );
65             for (int j = 0; j < N; ++j ) {
66                 R v2(0);
67                 CHECK_MESSAGE( n.try_get( v2 ), "" );
68                 CHECK_MESSAGE( v1 == v2, "" );
69             }
70         }
71         for (int i = 0; i < M; ++i) {
72             size_t c = r[i]->my_count;
73             CHECK_MESSAGE( int(c) == N+t%2, "" );
74         }
75         for (int i = 0; i < M; ++i) {
76             tbb::flow::remove_edge( n, *r[i] );
77         }
78         CHECK_MESSAGE( n.try_put( R(0) ), "" );
79         for (int i = 0; i < M; ++i) {
80             size_t c = r[i]->my_count;
81             CHECK_MESSAGE( int(c) == N+t%2, "" );
82         }
83         n.clear();
84         CHECK_MESSAGE( n.is_valid() == false, "" );
85         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
86     }
87 }
88 
89 template< typename R >
90 class native_body : utils::NoAssign {
91     tbb::flow::overwrite_node<R> &my_node;
92 
93 public:
94 
95     native_body( tbb::flow::overwrite_node<R> &n ) : my_node(n) {}
96 
97     void operator()( int i ) const {
98         R v1(static_cast<R>(i));
99         CHECK_MESSAGE( my_node.try_put( v1 ), "" );
100         CHECK_MESSAGE( my_node.is_valid() == true, "" );
101     }
102 };
103 
104 template< typename R >
105 void parallel_read_write_tests() {
106     tbb::flow::graph g;
107     tbb::flow::overwrite_node<R> n(g);
108     //Create a vector of identical nodes
109     std::vector< tbb::flow::overwrite_node<R> > ow_vec(2, n);
110 
111     for (size_t node_idx=0; node_idx<ow_vec.size(); ++node_idx) {
112         for ( int t = 0; t < T; ++t ) {
113             std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
114             for (size_t i = 0; i < M; ++i) {
115                 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
116             }
117 
118             for (int i = 0; i < M; ++i) {
119                 tbb::flow::make_edge( ow_vec[node_idx], *r[i] );
120             }
121             R v0;
122             CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" );
123             CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" );
124 
125 #if TBB_TEST_LOW_WORKLOAD
126             const int nthreads = 30;
127 #else
128             const int nthreads = N;
129 #endif
130             utils::NativeParallelFor( nthreads, native_body<R>( ow_vec[node_idx] ) );
131 
132             for (int i = 0; i < M; ++i) {
133                 size_t c = r[i]->my_count;
134                 CHECK_MESSAGE( int(c) == nthreads, "" );
135             }
136             for (int i = 0; i < M; ++i) {
137                 tbb::flow::remove_edge( ow_vec[node_idx], *r[i] );
138             }
139             CHECK_MESSAGE( ow_vec[node_idx].try_put( R(0) ), "" );
140             for (int i = 0; i < M; ++i) {
141                 size_t c = r[i]->my_count;
142                 CHECK_MESSAGE( int(c) == nthreads, "" );
143             }
144             ow_vec[node_idx].clear();
145             CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" );
146             CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" );
147         }
148     }
149 }
150 
151 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
152 #include <array>
153 #include <vector>
154 void test_follows_and_precedes_api() {
155     using msg_t = tbb::flow::continue_msg;
156 
157     std::array<msg_t, 3> messages_for_follows = { {msg_t(), msg_t(), msg_t()} };
158     std::vector<msg_t> messages_for_precedes = {msg_t()};
159 
160     follows_and_precedes_testing::test_follows<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_follows);
161     follows_and_precedes_testing::test_precedes<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_precedes);
162 }
163 #endif
164 
165 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
166 void test_deduction_guides() {
167     using namespace tbb::flow;
168 
169     graph g;
170     broadcast_node<int> b1(g);
171     overwrite_node<int> o0(g);
172 
173 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
174     overwrite_node o1(follows(b1));
175     static_assert(std::is_same_v<decltype(o1), overwrite_node<int>>);
176 
177     overwrite_node o2(precedes(b1));
178     static_assert(std::is_same_v<decltype(o2), overwrite_node<int>>);
179 #endif
180 
181     overwrite_node o3(o0);
182     static_assert(std::is_same_v<decltype(o3), overwrite_node<int>>);
183 }
184 #endif
185 
186 //! Test read-write properties
187 //! \brief \ref requirement \ref error_guessing
188 TEST_CASE("Read-write"){
189     simple_read_write_tests<int>();
190     simple_read_write_tests<float>();
191 }
192 
193 //! Read-write and ParallelFor tests under limited parallelism
194 //! \brief \ref error_guessing
195 TEST_CASE("Limited parallelism"){
196     for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
197         tbb::task_arena arena(p);
198         arena.execute(
199             [&]() {
200                 parallel_read_write_tests<int>();
201                 parallel_read_write_tests<float>();
202                 test_reserving_nodes<tbb::flow::overwrite_node, size_t>();
203             }
204         );
205 	}
206 }
207 
208 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
209 //! Test follows and precedes API
210 //! \brief \ref error_guessing
211 TEST_CASE("Follows and precedes API"){
212     test_follows_and_precedes_api();
213 }
214 #endif
215 
216 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Test deduction guides
218 //! \brief \ref requirement
219 TEST_CASE("Deduction guides"){
220     test_deduction_guides();
221 }
222 #endif
223 
224 //! Test try_release
225 //! \brief \ref error_guessing
226 TEST_CASE("try_release"){
227     tbb::flow::graph g;
228 
229     tbb::flow::overwrite_node<int> on(g);
230 
231     CHECK_MESSAGE ((on.try_release()== true), "try_release should return true");
232 }
233 
234 //! Test for cancel register_predecessor_task
235 //! \brief \ref error_guessing
236 TEST_CASE("Cancel register_predecessor_task") {
237     tbb::flow::graph g;
238     // Cancel graph context for preventing tasks execution and
239     // calling cancel method of spawned tasks
240     g.cancel();
241 
242     // To spawn register_predecessor_task internal buffer of overwrite_node
243     // should be valid and successor should failed during putting an item to it
244     oneapi::tbb::flow::overwrite_node<size_t> node(g);
245     // Reserving join_node always fails during putting an item to it
246     tbb::flow::join_node<std::tuple<size_t>, tbb::flow::reserving> j_node(g);
247 
248     // Make internal buffer of overwrite_node valid
249     node.try_put(1);
250     // Making an edge attempts pushing an item to join_node
251     // that immediately fails and tries to reverse an edge into PULL state
252     // by spawning register_predecessor_task, which will be cancelled
253     // during execution
254     tbb::flow::make_edge(node, tbb::flow::input_port<0>(j_node));
255 
256     // Wait for cancellation of spawned tasks
257     g.wait_for_all();
258 }
259