xref: /oneTBB/test/tbb/test_overwrite_node.cpp (revision 5d21288a)
151c0b2f7Stbbdev /*
2*5d21288aSVladimir Serov     Copyright (c) 2005-2022 Intel Corporation
351c0b2f7Stbbdev 
451c0b2f7Stbbdev     Licensed under the Apache License, Version 2.0 (the "License");
551c0b2f7Stbbdev     you may not use this file except in compliance with the License.
651c0b2f7Stbbdev     You may obtain a copy of the License at
751c0b2f7Stbbdev 
851c0b2f7Stbbdev         http://www.apache.org/licenses/LICENSE-2.0
951c0b2f7Stbbdev 
1051c0b2f7Stbbdev     Unless required by applicable law or agreed to in writing, software
1151c0b2f7Stbbdev     distributed under the License is distributed on an "AS IS" BASIS,
1251c0b2f7Stbbdev     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1351c0b2f7Stbbdev     See the License for the specific language governing permissions and
1451c0b2f7Stbbdev     limitations under the License.
1551c0b2f7Stbbdev */
1651c0b2f7Stbbdev 
1751c0b2f7Stbbdev #include "common/config.h"
1851c0b2f7Stbbdev 
1951c0b2f7Stbbdev #include "tbb/flow_graph.h"
2051c0b2f7Stbbdev 
2151c0b2f7Stbbdev #include "common/test.h"
2251c0b2f7Stbbdev #include "common/utils.h"
2351c0b2f7Stbbdev #include "common/utils_assert.h"
2451c0b2f7Stbbdev #include "common/graph_utils.h"
2551c0b2f7Stbbdev #include "common/test_follows_and_precedes_api.h"
2651c0b2f7Stbbdev 
2751c0b2f7Stbbdev 
2851c0b2f7Stbbdev //! \file test_overwrite_node.cpp
2951c0b2f7Stbbdev //! \brief Test for [flow_graph.overwrite_node] specification
3051c0b2f7Stbbdev 
3151c0b2f7Stbbdev 
// Test-size parameters shared by the read/write tests below:
//   N - number of values put into the node in the single-threaded test (and the
//       number of try_get repetitions per value); also the writer-thread count
//       of the parallel test unless TBB_TEST_LOW_WORKLOAD is set
//   T - number of rounds each test performs
//   M - number of counting receivers attached to the node per round
// Typed constants rather than single-letter macros: a macro named N, T or M
// would silently rewrite those identifiers in any code or header that follows.
constexpr int N = 300;
constexpr int T = 4;
constexpr int M = 5;
3551c0b2f7Stbbdev 
3651c0b2f7Stbbdev template< typename R >
simple_read_write_tests()3751c0b2f7Stbbdev void simple_read_write_tests() {
3851c0b2f7Stbbdev     tbb::flow::graph g;
3951c0b2f7Stbbdev     tbb::flow::overwrite_node<R> n(g);
4051c0b2f7Stbbdev 
4151c0b2f7Stbbdev     for ( int t = 0; t < T; ++t ) {
4251c0b2f7Stbbdev         R v0(N+1);
4351c0b2f7Stbbdev         std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
4451c0b2f7Stbbdev         for (size_t i = 0; i < M; ++i) {
4551c0b2f7Stbbdev             r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
4651c0b2f7Stbbdev         }
4751c0b2f7Stbbdev 
4851c0b2f7Stbbdev         CHECK_MESSAGE( n.is_valid() == false, "" );
4951c0b2f7Stbbdev         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
5051c0b2f7Stbbdev         if ( t % 2 ) {
5151c0b2f7Stbbdev             CHECK_MESSAGE( n.try_put( static_cast<R>(N) ), "" );
5251c0b2f7Stbbdev             CHECK_MESSAGE( n.is_valid() == true, "" );
5351c0b2f7Stbbdev             CHECK_MESSAGE( n.try_get( v0 ) == true, "" );
5451c0b2f7Stbbdev             CHECK_MESSAGE( v0 == R(N), "" );
5551c0b2f7Stbbdev         }
5651c0b2f7Stbbdev 
5751c0b2f7Stbbdev         for (int i = 0; i < M; ++i) {
5851c0b2f7Stbbdev             tbb::flow::make_edge( n, *r[i] );
5951c0b2f7Stbbdev         }
6051c0b2f7Stbbdev 
6151c0b2f7Stbbdev         for (int i = 0; i < N; ++i ) {
6251c0b2f7Stbbdev             R v1(static_cast<R>(i));
6351c0b2f7Stbbdev             CHECK_MESSAGE( n.try_put( v1 ), "" );
6451c0b2f7Stbbdev             CHECK_MESSAGE( n.is_valid() == true, "" );
6551c0b2f7Stbbdev             for (int j = 0; j < N; ++j ) {
6651c0b2f7Stbbdev                 R v2(0);
6751c0b2f7Stbbdev                 CHECK_MESSAGE( n.try_get( v2 ), "" );
6851c0b2f7Stbbdev                 CHECK_MESSAGE( v1 == v2, "" );
6951c0b2f7Stbbdev             }
7051c0b2f7Stbbdev         }
7151c0b2f7Stbbdev         for (int i = 0; i < M; ++i) {
7251c0b2f7Stbbdev             size_t c = r[i]->my_count;
7351c0b2f7Stbbdev             CHECK_MESSAGE( int(c) == N+t%2, "" );
7451c0b2f7Stbbdev         }
7551c0b2f7Stbbdev         for (int i = 0; i < M; ++i) {
7651c0b2f7Stbbdev             tbb::flow::remove_edge( n, *r[i] );
7751c0b2f7Stbbdev         }
7851c0b2f7Stbbdev         CHECK_MESSAGE( n.try_put( R(0) ), "" );
7951c0b2f7Stbbdev         for (int i = 0; i < M; ++i) {
8051c0b2f7Stbbdev             size_t c = r[i]->my_count;
8151c0b2f7Stbbdev             CHECK_MESSAGE( int(c) == N+t%2, "" );
8251c0b2f7Stbbdev         }
8351c0b2f7Stbbdev         n.clear();
8451c0b2f7Stbbdev         CHECK_MESSAGE( n.is_valid() == false, "" );
8551c0b2f7Stbbdev         CHECK_MESSAGE( n.try_get( v0 ) == false, "" );
8651c0b2f7Stbbdev     }
8751c0b2f7Stbbdev }
8851c0b2f7Stbbdev 
8951c0b2f7Stbbdev template< typename R >
9051c0b2f7Stbbdev class native_body : utils::NoAssign {
9151c0b2f7Stbbdev     tbb::flow::overwrite_node<R> &my_node;
9251c0b2f7Stbbdev 
9351c0b2f7Stbbdev public:
9451c0b2f7Stbbdev 
native_body(tbb::flow::overwrite_node<R> & n)9551c0b2f7Stbbdev     native_body( tbb::flow::overwrite_node<R> &n ) : my_node(n) {}
9651c0b2f7Stbbdev 
operator ()(int i) const9751c0b2f7Stbbdev     void operator()( int i ) const {
9851c0b2f7Stbbdev         R v1(static_cast<R>(i));
9951c0b2f7Stbbdev         CHECK_MESSAGE( my_node.try_put( v1 ), "" );
10051c0b2f7Stbbdev         CHECK_MESSAGE( my_node.is_valid() == true, "" );
10151c0b2f7Stbbdev     }
10251c0b2f7Stbbdev };
10351c0b2f7Stbbdev 
10451c0b2f7Stbbdev template< typename R >
parallel_read_write_tests()10551c0b2f7Stbbdev void parallel_read_write_tests() {
10651c0b2f7Stbbdev     tbb::flow::graph g;
10751c0b2f7Stbbdev     tbb::flow::overwrite_node<R> n(g);
10851c0b2f7Stbbdev     //Create a vector of identical nodes
10951c0b2f7Stbbdev     std::vector< tbb::flow::overwrite_node<R> > ow_vec(2, n);
11051c0b2f7Stbbdev 
11151c0b2f7Stbbdev     for (size_t node_idx=0; node_idx<ow_vec.size(); ++node_idx) {
11251c0b2f7Stbbdev         for ( int t = 0; t < T; ++t ) {
11351c0b2f7Stbbdev             std::vector< std::shared_ptr<harness_counting_receiver<R>> > r;
11451c0b2f7Stbbdev             for (size_t i = 0; i < M; ++i) {
11551c0b2f7Stbbdev                 r.push_back( std::make_shared<harness_counting_receiver<R>>(g) );
11651c0b2f7Stbbdev             }
11751c0b2f7Stbbdev 
11851c0b2f7Stbbdev             for (int i = 0; i < M; ++i) {
11951c0b2f7Stbbdev                 tbb::flow::make_edge( ow_vec[node_idx], *r[i] );
12051c0b2f7Stbbdev             }
12151c0b2f7Stbbdev             R v0;
12251c0b2f7Stbbdev             CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" );
12351c0b2f7Stbbdev             CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" );
12451c0b2f7Stbbdev 
12551c0b2f7Stbbdev #if TBB_TEST_LOW_WORKLOAD
12651c0b2f7Stbbdev             const int nthreads = 30;
12751c0b2f7Stbbdev #else
12851c0b2f7Stbbdev             const int nthreads = N;
12951c0b2f7Stbbdev #endif
13051c0b2f7Stbbdev             utils::NativeParallelFor( nthreads, native_body<R>( ow_vec[node_idx] ) );
13151c0b2f7Stbbdev 
13251c0b2f7Stbbdev             for (int i = 0; i < M; ++i) {
13351c0b2f7Stbbdev                 size_t c = r[i]->my_count;
13451c0b2f7Stbbdev                 CHECK_MESSAGE( int(c) == nthreads, "" );
13551c0b2f7Stbbdev             }
13651c0b2f7Stbbdev             for (int i = 0; i < M; ++i) {
13751c0b2f7Stbbdev                 tbb::flow::remove_edge( ow_vec[node_idx], *r[i] );
13851c0b2f7Stbbdev             }
13951c0b2f7Stbbdev             CHECK_MESSAGE( ow_vec[node_idx].try_put( R(0) ), "" );
14051c0b2f7Stbbdev             for (int i = 0; i < M; ++i) {
14151c0b2f7Stbbdev                 size_t c = r[i]->my_count;
14251c0b2f7Stbbdev                 CHECK_MESSAGE( int(c) == nthreads, "" );
14351c0b2f7Stbbdev             }
14451c0b2f7Stbbdev             ow_vec[node_idx].clear();
14551c0b2f7Stbbdev             CHECK_MESSAGE( ow_vec[node_idx].is_valid() == false, "" );
14651c0b2f7Stbbdev             CHECK_MESSAGE( ow_vec[node_idx].try_get( v0 ) == false, "" );
14751c0b2f7Stbbdev         }
14851c0b2f7Stbbdev     }
14951c0b2f7Stbbdev }
15051c0b2f7Stbbdev 
15151c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
15251c0b2f7Stbbdev #include <array>
15351c0b2f7Stbbdev #include <vector>
test_follows_and_precedes_api()15451c0b2f7Stbbdev void test_follows_and_precedes_api() {
15551c0b2f7Stbbdev     using msg_t = tbb::flow::continue_msg;
15651c0b2f7Stbbdev 
15751c0b2f7Stbbdev     std::array<msg_t, 3> messages_for_follows = { {msg_t(), msg_t(), msg_t()} };
15851c0b2f7Stbbdev     std::vector<msg_t> messages_for_precedes = {msg_t()};
15951c0b2f7Stbbdev 
16051c0b2f7Stbbdev     follows_and_precedes_testing::test_follows<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_follows);
16151c0b2f7Stbbdev     follows_and_precedes_testing::test_precedes<msg_t, tbb::flow::overwrite_node<msg_t>>(messages_for_precedes);
16251c0b2f7Stbbdev }
16351c0b2f7Stbbdev #endif
16451c0b2f7Stbbdev 
16551c0b2f7Stbbdev #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
test_deduction_guides()16651c0b2f7Stbbdev void test_deduction_guides() {
16751c0b2f7Stbbdev     using namespace tbb::flow;
16851c0b2f7Stbbdev 
16951c0b2f7Stbbdev     graph g;
17051c0b2f7Stbbdev     broadcast_node<int> b1(g);
17151c0b2f7Stbbdev     overwrite_node<int> o0(g);
17251c0b2f7Stbbdev 
17351c0b2f7Stbbdev #if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
17451c0b2f7Stbbdev     overwrite_node o1(follows(b1));
17551c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(o1), overwrite_node<int>>);
17651c0b2f7Stbbdev 
17751c0b2f7Stbbdev     overwrite_node o2(precedes(b1));
17851c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(o2), overwrite_node<int>>);
17951c0b2f7Stbbdev #endif
18051c0b2f7Stbbdev 
18151c0b2f7Stbbdev     overwrite_node o3(o0);
18251c0b2f7Stbbdev     static_assert(std::is_same_v<decltype(o3), overwrite_node<int>>);
18351c0b2f7Stbbdev }
18451c0b2f7Stbbdev #endif
18551c0b2f7Stbbdev 
18651c0b2f7Stbbdev //! Test read-write properties
18751c0b2f7Stbbdev //! \brief \ref requirement \ref error_guessing
18851c0b2f7Stbbdev TEST_CASE("Read-write"){
18951c0b2f7Stbbdev     simple_read_write_tests<int>();
19051c0b2f7Stbbdev     simple_read_write_tests<float>();
19151c0b2f7Stbbdev }
19251c0b2f7Stbbdev 
19351c0b2f7Stbbdev //! Read-write and ParallelFor tests under limited parallelism
19451c0b2f7Stbbdev //! \brief \ref error_guessing
19551c0b2f7Stbbdev TEST_CASE("Limited parallelism"){
19651c0b2f7Stbbdev     for( unsigned int p=utils::MinThread; p<=utils::MaxThread; ++p ) {
19751c0b2f7Stbbdev         tbb::task_arena arena(p);
19851c0b2f7Stbbdev         arena.execute(
__anona8c67c360102() 19951c0b2f7Stbbdev             [&]() {
20051c0b2f7Stbbdev                 parallel_read_write_tests<int>();
20151c0b2f7Stbbdev                 parallel_read_write_tests<float>();
20251c0b2f7Stbbdev                 test_reserving_nodes<tbb::flow::overwrite_node, size_t>();
20351c0b2f7Stbbdev             }
20451c0b2f7Stbbdev         );
20551c0b2f7Stbbdev 	}
20651c0b2f7Stbbdev }
20751c0b2f7Stbbdev 
#if __TBB_PREVIEW_FLOW_GRAPH_NODE_SET
//! Test follows and precedes API
//! \brief \ref error_guessing
TEST_CASE("Follows and precedes API"){
    // Only built when the preview node-set API is enabled.
    test_follows_and_precedes_api();
}
#endif
21551c0b2f7Stbbdev 
#if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT
//! Test deduction guides
//! \brief \ref requirement
TEST_CASE("Deduction guides"){
    // The checks are static_asserts; running the function only instantiates them.
    test_deduction_guides();
}
#endif
22351c0b2f7Stbbdev 
22451c0b2f7Stbbdev //! Test try_release
22551c0b2f7Stbbdev //! \brief \ref error_guessing
22651c0b2f7Stbbdev TEST_CASE("try_release"){
22751c0b2f7Stbbdev     tbb::flow::graph g;
22851c0b2f7Stbbdev 
22951c0b2f7Stbbdev     tbb::flow::overwrite_node<int> on(g);
23051c0b2f7Stbbdev 
23151c0b2f7Stbbdev     CHECK_MESSAGE ((on.try_release()== true), "try_release should return true");
23251c0b2f7Stbbdev }
233*5d21288aSVladimir Serov 
234*5d21288aSVladimir Serov //! Test for cancel register_predecessor_task
235*5d21288aSVladimir Serov //! \brief \ref error_guessing
236*5d21288aSVladimir Serov TEST_CASE("Cancel register_predecessor_task") {
237*5d21288aSVladimir Serov     tbb::flow::graph g;
238*5d21288aSVladimir Serov     // Cancel graph context for preventing tasks execution and
239*5d21288aSVladimir Serov     // calling cancel method of spawned tasks
240*5d21288aSVladimir Serov     g.cancel();
241*5d21288aSVladimir Serov 
242*5d21288aSVladimir Serov     // To spawn register_predecessor_task internal buffer of overwrite_node
243*5d21288aSVladimir Serov     // should be valid and successor should failed during putting an item to it
244*5d21288aSVladimir Serov     oneapi::tbb::flow::overwrite_node<size_t> node(g);
245*5d21288aSVladimir Serov     // Reserving join_node always fails during putting an item to it
246*5d21288aSVladimir Serov     tbb::flow::join_node<std::tuple<size_t>, tbb::flow::reserving> j_node(g);
247*5d21288aSVladimir Serov 
248*5d21288aSVladimir Serov     // Make internal buffer of overwrite_node valid
249*5d21288aSVladimir Serov     node.try_put(1);
250*5d21288aSVladimir Serov     // Making an edge attempts pushing an item to join_node
251*5d21288aSVladimir Serov     // that immediately fails and tries to reverse an edge into PULL state
252*5d21288aSVladimir Serov     // by spawning register_predecessor_task, which will be cancelled
253*5d21288aSVladimir Serov     // during execution
254*5d21288aSVladimir Serov     tbb::flow::make_edge(node, tbb::flow::input_port<0>(j_node));
255*5d21288aSVladimir Serov 
256*5d21288aSVladimir Serov     // Wait for cancellation of spawned tasks
257*5d21288aSVladimir Serov     g.wait_for_all();
258*5d21288aSVladimir Serov }
259