1 /* 2 Copyright (c) 2005-2021 Intel Corporation 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 #include "common/config.h" 18 19 #include "tbb/flow_graph.h" 20 21 #include "common/test.h" 22 #include "common/utils.h" 23 #include "common/utils_assert.h" 24 #include "common/graph_utils.h" 25 #include "common/test_follows_and_precedes_api.h" 26 27 28 //! \file test_overwrite_node.cpp 29 //! \brief Test for [flow_graph.overwrite_node] specification 30 31 32 #define N 300 33 #define T 4 34 #define M 5 35 36 template< typename R > 37 void simple_read_write_tests() { 38 tbb::flow::graph g; 39 tbb::flow::overwrite_node<R> n(g); 40 41 for ( int t = 0; t < T; ++t ) { 42 R v0(N+1); 43 std::vector< std::shared_ptr<harness_counting_receiver<R>> > r; 44 for (size_t i = 0; i < M; ++i) { 45 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) ); 46 } 47 48 CHECK_MESSAGE( n.is_valid() == false, "" ); 49 CHECK_MESSAGE( n.try_get( v0 ) == false, "" ); 50 if ( t % 2 ) { 51 CHECK_MESSAGE( n.try_put( static_cast<R>(N) ), "" ); 52 CHECK_MESSAGE( n.is_valid() == true, "" ); 53 CHECK_MESSAGE( n.try_get( v0 ) == true, "" ); 54 CHECK_MESSAGE( v0 == R(N), "" ); 55 } 56 57 for (int i = 0; i < M; ++i) { 58 tbb::flow::make_edge( n, *r[i] ); 59 } 60 61 for (int i = 0; i < N; ++i ) { 62 R v1(static_cast<R>(i)); 63 CHECK_MESSAGE( n.try_put( v1 ), "" ); 64 CHECK_MESSAGE( n.is_valid() == true, "" ); 65 for (int j = 0; j < N; ++j ) { 66 R v2(0); 67 CHECK_MESSAGE( n.try_get( v2 ), "" ); 68 CHECK_MESSAGE( v1 == v2, "" ); 69 } 70 } 71 for (int i = 0; i < M; ++i) { 72 size_t c = r[i]->my_count; 73 CHECK_MESSAGE( int(c) == N+t%2, "" ); 74 } 75 for (int i = 0; i < M; ++i) { 76 tbb::flow::remove_edge( n, *r[i] ); 77 } 78 CHECK_MESSAGE( n.try_put( R(0) ), "" ); 79 for (int i = 0; i < M; ++i) { 80 size_t c = r[i]->my_count; 81 CHECK_MESSAGE( int(c) == N+t%2, "" ); 82 } 83 n.clear(); 84 CHECK_MESSAGE( n.is_valid() == false, "" ); 85 CHECK_MESSAGE( n.try_get( v0 ) == false, "" ); 86 } 87 } 88 89 template< typename R > 90 class native_body : utils::NoAssign { 91 tbb::flow::overwrite_node<R> &my_node; 92 93 public: 94 95 native_body( tbb::flow::overwrite_node<R> &n ) : my_node(n) {} 96 97 void operator()( int i ) const { 98 R v1(static_cast<R>(i)); 99 CHECK_MESSAGE( my_node.try_put( v1 ), "" ); 100 CHECK_MESSAGE( my_node.is_valid() == true, "" ); 101 } 102 }; 103 104 template< typename R > 105 void parallel_read_write_tests() { 106 tbb::flow::graph g; 107 tbb::flow::overwrite_node<R> n(g); 108 //Create a vector of identical nodes 109 std::vector< tbb::flow::overwrite_node<R> > ow_vec(2, n); 110 111 for (size_t node_idx=0; node_idx<ow_vec.size(); ++node_idx) { 112 for ( int t = 0; t < T; ++t ) { 113 std::vector< std::shared_ptr<harness_counting_receiver<R>> > r; 114 for (size_t i = 0; i < M; ++i) { 115 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) ); 116 } 117 118 for (int i = 0; i < M; ++i) { 119 tbb::flow::make_edge( ow_vec[node_idx], *r[i] ); 120 } 121 R v0; 122 CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" ); 123 CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" ); 124 125 #if TBB_TEST_LOW_WORKLOAD 126 const int nthreads = 30; 127 #else 128 const int nthreads = N; 129 #endif 130 utils::NativeParallelFor( nthreads, native_body<R>( ow_vec[node_idx] ) ); 131 132 for (int i = 0; i < M; ++i) { 133 size_t c = r[i]->my_count; 134 CHECK_MESSAGE( int(c) == nthreads, "" ); 135 } 136 for (int i = 0; i < M; ++i) { 137 tbb::flow::remove_edge( ow_vec[node_idx], *r[i] ); 138 } 139 CHECK_MESSAGE( ow_vec[node_idx].try_put( R(0) ), "" ); 140 for (int i = 0; i < M; ++i) { 141 size_t c = r[i]->my_count; 142 CHECK_MESSAGE( int(c) == nthreads, "" ); 143 } 144 ow_vec[node_idx].clear(); 145 CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" ); 146 CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" ); 147 } 148 } 149 } 150 151 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 152 #include <array> 153 #include <vector> 154 void test_follows_and_precedes_api() { 155 using msg_t = tbb::flow::continue_msg; 156 157 std::array<msg_t, 3> messages_for_follows = { {msg_t(), msg_t(), msg_t()} }; 158 std::vector<msg_t> messages_for_precedes = {msg_t()}; 159 160 follows_and_precedes_testing::test_follows<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_follows); 161 follows_and_precedes_testing::test_precedes<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_precedes); 162 } 163 #endif 164 165 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 166 void test_deduction_guides() { 167 using namespace tbb::flow; 168 169 graph g; 170 broadcast_node<int> b1(g); 171 overwrite_node<int> o0(g); 172 173 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 174 overwrite_node o1(follows(b1)); 175 static_assert(std::is_same_v<decltype(o1), overwrite_node<int>>); 176 177 overwrite_node o2(precedes(b1)); 178 static_assert(std::is_same_v<decltype(o2), overwrite_node<int>>); 179 #endif 180 181 overwrite_node o3(o0); 182 static_assert(std::is_same_v<decltype(o3), overwrite_node<int>>); 183 } 184 #endif 185 186 //! Test read-write properties 187 //! \brief \ref requirement \ref error_guessing 188 TEST_CASE("Read-write"){ 189 simple_read_write_tests<int>(); 190 simple_read_write_tests<float>(); 191 } 192 193 //! Read-write and ParallelFor tests under limited parallelism 194 //! \brief \ref error_guessing 195 TEST_CASE("Limited parallelism"){ 196 for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) { 197 tbb::task_arena arena(p); 198 arena.execute( 199 [&]() { 200 parallel_read_write_tests<int>(); 201 parallel_read_write_tests<float>(); 202 test_reserving_nodes<tbb::flow::overwrite_node, size_t>(); 203 } 204 ); 205 } 206 } 207 208 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 209 //! Test follows and precedes API 210 //! \brief \ref error_guessing 211 TEST_CASE("Follows and precedes API"){ 212 test_follows_and_precedes_api(); 213 } 214 #endif 215 216 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 217 //! Test decution guides 218 //! \brief \ref requirement 219 TEST_CASE("Deduction guides"){ 220 test_deduction_guides(); 221 } 222 #endif 223 224 //! Test try_release 225 //! \brief \ref error_guessing 226 TEST_CASE("try_release"){ 227 tbb::flow::graph g; 228 229 tbb::flow::overwrite_node<int> on(g); 230 231 CHECK_MESSAGE ((on.try_release()== true), "try_release should return true"); 232 } 233