1 /*
2 Copyright (c) 2005-2022 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17 #include "common/config.h"
18
19 #include "tbb/flow_graph.h"
20
21 #include "common/test.h"
22 #include "common/utils.h"
23 #include "common/utils_assert.h"
24 #include "common/graph_utils.h"
25 #include "common/test_follows_and_precedes_api.h"
26
27
28 //! \file test_overwrite_node.cpp
29 //! \brief Test for [flow_graph.overwrite_node] specification
30
31
// N: number of values put per round in the serial test (each put is followed
//    by N try_get calls), and the thread count of the parallel test when
//    TBB_TEST_LOW_WORKLOAD is not set.
#define N 300
// T: number of rounds each read/write test performs.
#define T 4
// M: number of counting receivers attached to the node under test per round.
#define M 5
35
36 template< typename R >
simple_read_write_tests()37 void simple_read_write_tests() {
38 tbb::flow::graph g;
39 tbb::flow::overwrite_node<R> n(g);
40
41 for ( int t = 0; t < T; ++t ) {
42 R v0(N+1);
43 std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
44 for (size_t i = 0; i < M; ++i) {
45 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
46 }
47
48 CHECK_MESSAGE( n.is_valid() == false, "" );
49 CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
50 if ( t % 2 ) {
51 CHECK_MESSAGE( n.try_put( static_cast<R>(N) ), "" );
52 CHECK_MESSAGE( n.is_valid() == true, "" );
53 CHECK_MESSAGE( n.try_get( v0 ) == true, "" );
54 CHECK_MESSAGE( v0 == R(N), "" );
55 }
56
57 for (int i = 0; i < M; ++i) {
58 tbb::flow::make_edge( n, *r[i] );
59 }
60
61 for (int i = 0; i < N; ++i ) {
62 R v1(static_cast<R>(i));
63 CHECK_MESSAGE( n.try_put( v1 ), "" );
64 CHECK_MESSAGE( n.is_valid() == true, "" );
65 for (int j = 0; j < N; ++j ) {
66 R v2(0);
67 CHECK_MESSAGE( n.try_get( v2 ), "" );
68 CHECK_MESSAGE( v1 == v2, "" );
69 }
70 }
71 for (int i = 0; i < M; ++i) {
72 size_t c = r[i]->my_count;
73 CHECK_MESSAGE( int(c) == N+t%2, "" );
74 }
75 for (int i = 0; i < M; ++i) {
76 tbb::flow::remove_edge( n, *r[i] );
77 }
78 CHECK_MESSAGE( n.try_put( R(0) ), "" );
79 for (int i = 0; i < M; ++i) {
80 size_t c = r[i]->my_count;
81 CHECK_MESSAGE( int(c) == N+t%2, "" );
82 }
83 n.clear();
84 CHECK_MESSAGE( n.is_valid() == false, "" );
85 CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
86 }
87 }
88
89 template< typename R >
90 class native_body : utils::NoAssign {
91 tbb::flow::overwrite_node<R> &my_node;
92
93 public:
94
native_body(tbb::flow::overwrite_node<R> & n)95 native_body( tbb::flow::overwrite_node<R> &n ) : my_node(n) {}
96
operator ()(int i) const97 void operator()( int i ) const {
98 R v1(static_cast<R>(i));
99 CHECK_MESSAGE( my_node.try_put( v1 ), "" );
100 CHECK_MESSAGE( my_node.is_valid() == true, "" );
101 }
102 };
103
104 template< typename R >
parallel_read_write_tests()105 void parallel_read_write_tests() {
106 tbb::flow::graph g;
107 tbb::flow::overwrite_node<R> n(g);
108 //Create a vector of identical nodes
109 std::vector< tbb::flow::overwrite_node<R> > ow_vec(2, n);
110
111 for (size_t node_idx=0; node_idx<ow_vec.size(); ++node_idx) {
112 for ( int t = 0; t < T; ++t ) {
113 std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
114 for (size_t i = 0; i < M; ++i) {
115 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
116 }
117
118 for (int i = 0; i < M; ++i) {
119 tbb::flow::make_edge( ow_vec[node_idx], *r[i] );
120 }
121 R v0;
122 CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" );
123 CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" );
124
125 #if TBB_TEST_LOW_WORKLOAD
126 const int nthreads = 30;
127 #else
128 const int nthreads = N;
129 #endif
130 utils::NativeParallelFor( nthreads, native_body<R>( ow_vec[node_idx] ) );
131
132 for (int i = 0; i < M; ++i) {
133 size_t c = r[i]->my_count;
134 CHECK_MESSAGE( int(c) == nthreads, "" );
135 }
136 for (int i = 0; i < M; ++i) {
137 tbb::flow::remove_edge( ow_vec[node_idx], *r[i] );
138 }
139 CHECK_MESSAGE( ow_vec[node_idx].try_put( R(0) ), "" );
140 for (int i = 0; i < M; ++i) {
141 size_t c = r[i]->my_count;
142 CHECK_MESSAGE( int(c) == nthreads, "" );
143 }
144 ow_vec[node_idx].clear();
145 CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" );
146 CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" );
147 }
148 }
149 }
150
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
#include <array>
#include <vector>
//! Runs the shared follows()/precedes() construction checks for
//! overwrite_node<continue_msg>.
void test_follows_and_precedes_api() {
    using msg_t = tbb::flow::continue_msg;
    using node_t = tbb::flow::overwrite_node<msg_t>;
    namespace fp = follows_and_precedes_testing;

    // Input messages for the two helpers: three for follows(), one for
    // precedes(). All are value-initialized continue_msg objects.
    std::array<msg_t, 3> messages_for_follows{};
    std::vector<msg_t> messages_for_precedes(1);

    fp::test_follows<msg_t, node_t>(messages_for_follows);
    fp::test_precedes<msg_t, node_t>(messages_for_precedes);
}
#endif
164
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Compile-time checks that class template argument deduction yields
//! overwrite_node<int> from node sets and from copy construction.
void test_deduction_guides() {
    using namespace tbb::flow;

    graph g;
    broadcast_node<int> bcast(g);
    overwrite_node<int> origin(g);

#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
    // Message type deduced from a predecessor set...
    overwrite_node from_follows(follows(bcast));
    static_assert(std::is_same_v<decltype(from_follows), overwrite_node<int>>);

    // ...and from a successor set.
    overwrite_node from_precedes(precedes(bcast));
    static_assert(std::is_same_v<decltype(from_precedes), overwrite_node<int>>);
#endif

    // Copy construction deduces the source node's specialization.
    overwrite_node copied(origin);
    static_assert(std::is_same_v<decltype(copied), overwrite_node<int>>);
}
#endif
185
//! Test read-write properties
//! \brief \ref requirement \ref error_guessing
TEST_CASE("Read-write"){
    // Serial checks with an integral and a floating-point payload type.
    simple_read_write_tests<int>();
    simple_read_write_tests<float>();
}
192
193 //! Read-write and ParallelFor tests under limited parallelism
194 //! \brief \ref error_guessing
195 TEST_CASE("Limited parallelism"){
196 for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
197 tbb::task_arena arena(p);
198 arena.execute(
__anona8c67c360102() 199 [&]() {
200 parallel_read_write_tests<int>();
201 parallel_read_write_tests<float>();
202 test_reserving_nodes<tbb::flow::overwrite_node, size_t>();
203 }
204 );
205 }
206 }
207
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
//! Test follows and precedes API
//! \brief \ref error_guessing
TEST_CASE("Follows and precedes API"){
    // Compiled only when the preview node-set API is enabled.
    test_follows_and_precedes_api();
}
#endif
215
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Test deduction guides
//! \brief \ref requirement
TEST_CASE("Deduction guides"){
    // The checks are static_asserts; reaching this call means they compiled.
    test_deduction_guides();
}
#endif
223
224 //! Test try_release
225 //! \brief \ref error_guessing
226 TEST_CASE("try_release"){
227 tbb::flow::graph g;
228
229 tbb::flow::overwrite_node<int> on(g);
230
231 CHECK_MESSAGE ((on.try_release()== true), "try_release should return true");
232 }
233
//! Test for cancel register_predecessor_task
//! \brief \ref error_guessing
TEST_CASE("Cancel register_predecessor_task") {
    tbb::flow::graph g;
    // Cancel the graph's context up front so that tasks spawned later are not
    // executed; instead their cancel method is called.
    g.cancel();

    // register_predecessor_task is spawned only when the overwrite_node's
    // internal buffer holds a valid item and the successor rejects that item.
    oneapi::tbb::flow::overwrite_node<size_t> node(g);
    // A reserving join_node always rejects items pushed to it.
    tbb::flow::join_node<std::tuple<size_t>, tbb::flow::reserving> j_node(g);

    // Make the overwrite_node's internal buffer valid.
    node.try_put(1);
    // Creating the edge tries to push the buffered item to the join_node.
    // The push fails immediately, so the node tries to switch the edge to
    // PULL mode by spawning register_predecessor_task. Because the graph is
    // already cancelled, that task is cancelled rather than executed.
    tbb::flow::make_edge(node, tbb::flow::input_port<0>(j_node));

    // Wait for the spawned tasks to be cancelled; the test passes if this
    // returns cleanly.
    g.wait_for_all();
}
259