/*
    Copyright (c) 2005-2021 Intel Corporation

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/

#include "common/config.h"

// TODO revamp: move parts dependent on __TBB_EXTRA_DEBUG into separate test(s) since having these
// parts in all of tests might make testing of the product, which is different from what is actually
// released.
#define __TBB_EXTRA_DEBUG 1
#include "tbb/flow_graph.h"

#include "common/test.h"
#include "common/utils.h"
#include "common/utils_assert.h"
#include "common/graph_utils.h"
#include "common/test_follows_and_precedes_api.h"

#include <memory>
#include <vector>


//! \file test_overwrite_node.cpp
\brief Test for [flow_graph.overwrite_node] specification 34 35 36 #define N 300 37 #define T 4 38 #define M 5 39 40 template< typename R > 41 void simple_read_write_tests() { 42 tbb::flow::graph g; 43 tbb::flow::overwrite_node<R> n(g); 44 45 for ( int t = 0; t < T; ++t ) { 46 R v0(N+1); 47 std::vector< std::shared_ptr<harness_counting_receiver<R>> > r; 48 for (size_t i = 0; i < M; ++i) { 49 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) ); 50 } 51 52 CHECK_MESSAGE( n.is_valid() == false, "" ); 53 CHECK_MESSAGE( n.try_get( v0 ) == false, "" ); 54 if ( t % 2 ) { 55 CHECK_MESSAGE( n.try_put( static_cast<R>(N) ), "" ); 56 CHECK_MESSAGE( n.is_valid() == true, "" ); 57 CHECK_MESSAGE( n.try_get( v0 ) == true, "" ); 58 CHECK_MESSAGE( v0 == R(N), "" ); 59 } 60 61 for (int i = 0; i < M; ++i) { 62 tbb::flow::make_edge( n, *r[i] ); 63 } 64 65 for (int i = 0; i < N; ++i ) { 66 R v1(static_cast<R>(i)); 67 CHECK_MESSAGE( n.try_put( v1 ), "" ); 68 CHECK_MESSAGE( n.is_valid() == true, "" ); 69 for (int j = 0; j < N; ++j ) { 70 R v2(0); 71 CHECK_MESSAGE( n.try_get( v2 ), "" ); 72 CHECK_MESSAGE( v1 == v2, "" ); 73 } 74 } 75 for (int i = 0; i < M; ++i) { 76 size_t c = r[i]->my_count; 77 CHECK_MESSAGE( int(c) == N+t%2, "" ); 78 } 79 for (int i = 0; i < M; ++i) { 80 tbb::flow::remove_edge( n, *r[i] ); 81 } 82 CHECK_MESSAGE( n.try_put( R(0) ), "" ); 83 for (int i = 0; i < M; ++i) { 84 size_t c = r[i]->my_count; 85 CHECK_MESSAGE( int(c) == N+t%2, "" ); 86 } 87 n.clear(); 88 CHECK_MESSAGE( n.is_valid() == false, "" ); 89 CHECK_MESSAGE( n.try_get( v0 ) == false, "" ); 90 } 91 } 92 93 template< typename R > 94 class native_body : utils::NoAssign { 95 tbb::flow::overwrite_node<R> &my_node; 96 97 public: 98 99 native_body( tbb::flow::overwrite_node<R> &n ) : my_node(n) {} 100 101 void operator()( int i ) const { 102 R v1(static_cast<R>(i)); 103 CHECK_MESSAGE( my_node.try_put( v1 ), "" ); 104 CHECK_MESSAGE( my_node.is_valid() == true, "" ); 105 } 106 }; 107 108 template< 
typename R > 109 void parallel_read_write_tests() { 110 tbb::flow::graph g; 111 tbb::flow::overwrite_node<R> n(g); 112 //Create a vector of identical nodes 113 std::vector< tbb::flow::overwrite_node<R> > ow_vec(2, n); 114 115 for (size_t node_idx=0; node_idx<ow_vec.size(); ++node_idx) { 116 for ( int t = 0; t < T; ++t ) { 117 std::vector< std::shared_ptr<harness_counting_receiver<R>> > r; 118 for (size_t i = 0; i < M; ++i) { 119 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) ); 120 } 121 122 for (int i = 0; i < M; ++i) { 123 tbb::flow::make_edge( ow_vec[node_idx], *r[i] ); 124 } 125 R v0; 126 CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" ); 127 CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" ); 128 129 #if TBB_TEST_LOW_WORKLOAD 130 const int nthreads = 30; 131 #else 132 const int nthreads = N; 133 #endif 134 utils::NativeParallelFor( nthreads, native_body<R>( ow_vec[node_idx] ) ); 135 136 for (int i = 0; i < M; ++i) { 137 size_t c = r[i]->my_count; 138 CHECK_MESSAGE( int(c) == nthreads, "" ); 139 } 140 for (int i = 0; i < M; ++i) { 141 tbb::flow::remove_edge( ow_vec[node_idx], *r[i] ); 142 } 143 CHECK_MESSAGE( ow_vec[node_idx].try_put( R(0) ), "" ); 144 for (int i = 0; i < M; ++i) { 145 size_t c = r[i]->my_count; 146 CHECK_MESSAGE( int(c) == nthreads, "" ); 147 } 148 ow_vec[node_idx].clear(); 149 CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" ); 150 CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" ); 151 } 152 } 153 } 154 155 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 156 #include <array> 157 #include <vector> 158 void test_follows_and_precedes_api() { 159 using msg_t = tbb::flow::continue_msg; 160 161 std::array<msg_t, 3> messages_for_follows = { {msg_t(), msg_t(), msg_t()} }; 162 std::vector<msg_t> messages_for_precedes = {msg_t()}; 163 164 follows_and_precedes_testing::test_follows<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_follows); 165 follows_and_precedes_testing::test_precedes<msg_t, 
tbb::flow::overwrite_node<msg_t>>(messages_for_precedes); 166 } 167 #endif 168 169 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 170 void test_deduction_guides() { 171 using namespace tbb::flow; 172 173 graph g; 174 broadcast_node<int> b1(g); 175 overwrite_node<int> o0(g); 176 177 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 178 overwrite_node o1(follows(b1)); 179 static_assert(std::is_same_v<decltype(o1), overwrite_node<int>>); 180 181 overwrite_node o2(precedes(b1)); 182 static_assert(std::is_same_v<decltype(o2), overwrite_node<int>>); 183 #endif 184 185 overwrite_node o3(o0); 186 static_assert(std::is_same_v<decltype(o3), overwrite_node<int>>); 187 } 188 #endif 189 190 //! Test read-write properties 191 //! \brief \ref requirement \ref error_guessing 192 TEST_CASE("Read-write"){ 193 simple_read_write_tests<int>(); 194 simple_read_write_tests<float>(); 195 } 196 197 //! Read-write and ParallelFor tests under limited parallelism 198 //! \brief \ref error_guessing 199 TEST_CASE("Limited parallelism"){ 200 for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) { 201 tbb::task_arena arena(p); 202 arena.execute( 203 [&]() { 204 parallel_read_write_tests<int>(); 205 parallel_read_write_tests<float>(); 206 test_reserving_nodes<tbb::flow::overwrite_node, size_t>(); 207 } 208 ); 209 } 210 } 211 212 #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET 213 //! Test follows and precedes API 214 //! \brief \ref error_guessing 215 TEST_CASE("Follows and precedes API"){ 216 test_follows_and_precedes_api(); 217 } 218 #endif 219 220 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 221 //! Test decution guides 222 //! \brief \ref requirement 223 TEST_CASE("Deduction guides"){ 224 test_deduction_guides(); 225 } 226 #endif 227 228 //! Test try_release 229 //! \brief \ref error_guessing 230 TEST_CASE("try_release"){ 231 tbb::flow::graph g; 232 233 tbb::flow::overwrite_node<int> on(g); 234 235 CHECK_MESSAGE ((on.try_release()== true), "try_release should return true"); 236 } 237