1 /* 2 Copyright (c) 2020-2021 Intel Corporation 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 #if __INTEL_COMPILER && _MSC_VER 18 #pragma warning(disable : 2586) // decorated name length exceeded, name was truncated 19 #endif 20 21 #include "common/test.h" 22 23 #include "common/utils.h" 24 #include "common/graph_utils.h" 25 26 #include "oneapi/tbb/flow_graph.h" 27 #include "oneapi/tbb/task_arena.h" 28 29 #include "conformance_flowgraph.h" 30 31 //! \file conformance_input_node.cpp 32 //! \brief Test for [flow_graph.input_node] specification 33 34 /* 35 TODO: implement missing conformance tests for input_node: 36 - [ ] The `copy_body' function copies altered body (e.g. after its successful invocation). 37 - [ ] Check that in `test_forwarding' the value passed is the actual one received. 38 - [ ] Improve CTAD test to assert result node type. 39 - [ ] Explicit test for copy constructor of the node. 40 - [ ] Check `Output' type indeed copy-constructed and copy-assigned while working with the node. 41 - [ ] Check node cannot have predecessors (Will ADL be of any help here?) 42 - [ ] Check the node is serial and its body never invoked concurrently. 43 - [ ] `try_get()' call testing: a call to body is made only when the internal buffer is empty. 44 */ 45 46 std::atomic<size_t> global_execute_count; 47 48 template<typename OutputType> 49 struct input_functor { 50 const size_t n; 51 52 input_functor( ) : n(10) { } 53 input_functor( const input_functor &f ) : n(f.n) { } 54 void operator=(const input_functor &f) { n = f.n; } 55 56 OutputType operator()( oneapi::tbb::flow_control & fc ) { 57 ++global_execute_count; 58 if(global_execute_count > n){ 59 fc.stop(); 60 return OutputType(); 61 } 62 return OutputType(global_execute_count.load()); 63 } 64 65 }; 66 67 template<typename O> 68 struct CopyCounterBody{ 69 size_t copy_count; 70 71 CopyCounterBody(): 72 copy_count(0) {} 73 74 CopyCounterBody(const CopyCounterBody<O>& other): 75 copy_count(other.copy_count + 1) {} 76 77 CopyCounterBody& operator=(const CopyCounterBody<O>& other) { 78 copy_count = other.copy_count + 1; return *this; 79 } 80 81 O operator()(oneapi::tbb::flow_control & fc){ 82 fc.stop(); 83 return O(); 84 } 85 }; 86 87 88 void test_input_body(){ 89 oneapi::tbb::flow::graph g; 90 input_functor<int> fun; 91 92 global_execute_count = 0; 93 oneapi::tbb::flow::input_node<int> node1(g, fun); 94 test_push_receiver<int> node2(g); 95 96 oneapi::tbb::flow::make_edge(node1, node2); 97 98 node1.activate(); 99 g.wait_for_all(); 100 101 CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages"); 102 CHECK_MESSAGE( (global_execute_count == 10 + 1), "Body of the node needs to be executed N + 1 times"); 103 } 104 105 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 106 void test_deduction_guides(){ 107 oneapi::tbb::flow::graph g; 108 input_functor<int> fun; 109 oneapi::tbb::flow::input_node node1(g, fun); 110 } 111 #endif 112 113 void test_buffering(){ 114 oneapi::tbb::flow::graph g; 115 input_functor<int> fun; 116 global_execute_count = 0; 117 118 oneapi::tbb::flow::input_node<int> source(g, fun); 119 oneapi::tbb::flow::limiter_node<int> rejecter(g, 0); 120 121 oneapi::tbb::flow::make_edge(source, rejecter); 122 source.activate(); 123 g.wait_for_all(); 124 125 int tmp = -1; 126 CHECK_MESSAGE( (source.try_get(tmp) == true), "try_get after rejection should succeed"); 127 CHECK_MESSAGE( (tmp == 1), "try_get should return correct value"); 128 } 129 130 void test_forwarding(){ 131 oneapi::tbb::flow::graph g; 132 input_functor<int> fun; 133 134 global_execute_count = 0; 135 oneapi::tbb::flow::input_node<int> node1(g, fun); 136 test_push_receiver<int> node2(g); 137 test_push_receiver<int> node3(g); 138 139 oneapi::tbb::flow::make_edge(node1, node2); 140 oneapi::tbb::flow::make_edge(node1, node3); 141 142 node1.activate(); 143 g.wait_for_all(); 144 145 CHECK_MESSAGE( (get_count(node2) == 10), "Descendant of the node needs to be receive N messages"); 146 CHECK_MESSAGE( (get_count(node3) == 10), "Descendant of the node needs to be receive N messages"); 147 } 148 149 template<typename O> 150 void test_inheritance(){ 151 using namespace oneapi::tbb::flow; 152 153 CHECK_MESSAGE( (std::is_base_of<graph_node, input_node<O>>::value), "input_node should be derived from graph_node"); 154 CHECK_MESSAGE( (std::is_base_of<sender<O>, input_node<O>>::value), "input_node should be derived from sender<Output>"); 155 } 156 157 void test_copies(){ 158 using namespace oneapi::tbb::flow; 159 160 CopyCounterBody<int> b; 161 162 graph g; 163 input_node<int> fn(g, b); 164 165 CopyCounterBody<int> b2 = copy_body<CopyCounterBody<int>, input_node<int>>(fn); 166 167 CHECK_MESSAGE( (b.copy_count + 2 <= b2.copy_count), "copy_body and constructor should copy bodies"); 168 } 169 170 //! Test body copying and copy_body logic 171 //! \brief \ref interface 172 TEST_CASE("input_node and body copying"){ 173 test_copies(); 174 } 175 176 //! Test inheritance relations 177 //! \brief \ref interface 178 TEST_CASE("input_node superclasses"){ 179 test_inheritance<int>(); 180 test_inheritance<void*>(); 181 } 182 183 //! Test input_node forwarding 184 //! \brief \ref requirement 185 TEST_CASE("input_node forwarding"){ 186 test_forwarding(); 187 } 188 189 //! Test input_node buffering 190 //! \brief \ref requirement 191 TEST_CASE("input_node buffering"){ 192 test_buffering(); 193 } 194 195 //! Test calling input_node body 196 //! \brief \ref interface \ref requirement 197 TEST_CASE("input_node body") { 198 test_input_body(); 199 } 200 201 //! Test deduction guides 202 //! \brief \ref interface \ref requirement 203 TEST_CASE("Deduction guides"){ 204 #if __TBB_CPP17_DEDUCTION_GUIDES_PRESENT 205 test_deduction_guides(); 206 #endif 207 } 208