xref: /oneTBB/test/common/test_comparisons.h (revision 8b6f831c)
1 /*
2     Copyright (c) 2020-2021 Intel Corporation
3 
4     Licensed under the Apache License, Version 2.0 (the "License");
5     you may not use this file except in compliance with the License.
6     You may obtain a copy of the License at
7 
8         http://www.apache.org/licenses/LICENSE-2.0
9 
10     Unless required by applicable law or agreed to in writing, software
11     distributed under the License is distributed on an "AS IS" BASIS,
12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13     See the License for the specific language governing permissions and
14     limitations under the License.
15 */
16 
17 #ifndef __TBB_test_common_test_comparisons_H
18 #define __TBB_test_common_test_comparisons_H
19 
20 #include "test.h"
21 
22 #ifndef __TBB_TEST_CPP20_COMPARISONS
23 #define __TBB_TEST_CPP20_COMPARISONS __TBB_CPP20_COMPARISONS_PRESENT
24 #endif
25 
26 #if __TBB_TEST_CPP20_COMPARISONS
27 #include <compare>
28 #endif
29 
30 namespace comparisons_testing {
31 
32 template <bool ExpectEqual, bool ExpectLess, typename T>
testTwoWayComparisons(const T & lhs,const T & rhs)33 void testTwoWayComparisons( const T& lhs, const T& rhs ) {
34     REQUIRE_MESSAGE(((lhs < rhs) == ExpectLess),
35                     "Incorrect 2-way comparison result for less operation");
36     REQUIRE_MESSAGE(((lhs <= rhs) == (ExpectLess || ExpectEqual)),
37                     "Incorrect 2-way comparison result for less or equal operation");
38     bool ExpectGreater = ExpectEqual ? false : !ExpectLess;
39     REQUIRE_MESSAGE(((lhs > rhs) == ExpectGreater),
40                     "Incorrect 2-way comparison result for greater operation");
41     REQUIRE_MESSAGE(((lhs >= rhs) == (ExpectGreater || ExpectEqual)),
42                     "Incorrect 2-way comparison result for greater or equal operation");
43 }
44 
45 template <bool ExpectEqual, typename T>
testEqualityComparisons(const T & lhs,const T & rhs)46 void testEqualityComparisons( const T& lhs, const T& rhs ) {
47     REQUIRE_MESSAGE((lhs == rhs) == ExpectEqual,
48                     "Incorrect 2-way comparison result for equal operation");
49     REQUIRE_MESSAGE((lhs != rhs) == !ExpectEqual,
50                     "Incorrect 2-way comparison result for unequal operation");
51 }
52 
53 #if __TBB_TEST_CPP20_COMPARISONS
54 template <bool ExpectEqual, bool ExpectLess, typename T>
testThreeWayComparisons(const T & lhs,const T & rhs)55 void testThreeWayComparisons( const T& lhs, const T& rhs ) {
56     auto three_way_result = lhs <=> rhs;
57     REQUIRE_MESSAGE((three_way_result < 0) == ExpectLess,
58                     "Incorrect 3-way comparison result for less operation");
59     REQUIRE_MESSAGE((lhs <=> rhs <= 0) == (ExpectLess || ExpectEqual),
60                     "Incorrect 3-way comparison result for less or equal operation");
61     bool ExpectGreater = ExpectEqual ? false : !ExpectLess;
62     REQUIRE_MESSAGE((lhs <=> rhs > 0) == ExpectGreater,
63                     "Incorrect 3-way comparison result for greater operation");
64     REQUIRE_MESSAGE((lhs <=> rhs >= 0) == (ExpectGreater || ExpectEqual),
65                     "Incorrect 3-way comparison result for greater or equal operation");
66     REQUIRE_MESSAGE((lhs <=> rhs == 0) == ExpectEqual,
67                     "Incorrect 3-way comparison result for equal operation");
68     REQUIRE_MESSAGE((lhs <=> rhs != 0) == !ExpectEqual,
69                     "Incorrect 3-way comparison result for unequal operation");
70 }
71 
72 #endif // __TBB_TEST_CPP20_COMPARISONS
73 
74 template <bool ExpectEqual, bool ExpectLess, typename T>
testEqualityAndLessComparisons(const T & lhs,const T & rhs)75 void testEqualityAndLessComparisons( const T& lhs, const T& rhs ) {
76     testEqualityComparisons<ExpectEqual>(lhs, rhs);
77     testTwoWayComparisons<ExpectEqual, ExpectLess>(lhs, rhs);
78 #if __TBB_TEST_CPP20_COMPARISONS
79     testThreeWayComparisons<ExpectEqual, ExpectLess>(lhs, rhs);
80 #endif
81 }
82 
// Wrapper over a number that provides all six pre-C++20 comparison operators.
// Every operator raises its own static flag when invoked, which allows a test
// to find out which operators were used during a comparison.
class TwoWayComparable {
public:
    // One flag per comparison operator, set to true by the corresponding operator
    static bool equal_called;
    static bool unequal_called;
    static bool less_called;
    static bool greater_called;
    static bool less_or_equal_called;
    static bool greater_or_equal_called;

    // Drops all the flags
    static void reset() {
        equal_called = unequal_called = false;
        less_called = greater_called = false;
        less_or_equal_called = greater_or_equal_called = false;
    }

    // Stores zero; the flags are dropped by the delegated constructor
    TwoWayComparable() : TwoWayComparable(0) {}

    // Stores the number and drops all the flags
    TwoWayComparable( std::size_t value ) : n(value) {
        reset();
    }

    friend bool operator==( const TwoWayComparable& lhs, const TwoWayComparable& rhs ) {
        equal_called = true;
        return rhs.n == lhs.n;
    }

    friend bool operator!=( const TwoWayComparable& lhs, const TwoWayComparable& rhs ) {
        unequal_called = true;
        return rhs.n != lhs.n;
    }

    friend bool operator<( const TwoWayComparable& lhs, const TwoWayComparable& rhs ) {
        less_called = true;
        return rhs.n > lhs.n;
    }

    friend bool operator>( const TwoWayComparable& lhs, const TwoWayComparable& rhs ) {
        greater_called = true;
        return rhs.n < lhs.n;
    }

    friend bool operator<=( const TwoWayComparable& lhs, const TwoWayComparable& rhs ) {
        less_or_equal_called = true;
        return rhs.n >= lhs.n;
    }

    friend bool operator>=( const TwoWayComparable& lhs, const TwoWayComparable& rhs ) {
        greater_or_equal_called = true;
        return rhs.n <= lhs.n;
    }

protected:
    // The wrapped number all the comparisons are based on
    std::size_t n;

    // The hash specialization reads the wrapped number directly
    friend struct std::hash<TwoWayComparable>;
}; // class TwoWayComparable

// Definitions of the flags; initially no operator is considered to be called
bool TwoWayComparable::equal_called = false;
bool TwoWayComparable::unequal_called = false;
bool TwoWayComparable::less_called = false;
bool TwoWayComparable::greater_called = false;
bool TwoWayComparable::less_or_equal_called = false;
bool TwoWayComparable::greater_or_equal_called = false;
151 
152 // This function should be executed after comparing two objects, containing TwoWayComparables
153 // using one of the comparison operators (<=>, <, >, <=, >=)
154 void check_two_way_comparison() {
155     REQUIRE_MESSAGE(TwoWayComparable::less_called,
156                     "operator < was not called during the comparison");
157     REQUIRE_MESSAGE(!TwoWayComparable::greater_called,
158                     "operator > was called during the comparison");
159     REQUIRE_MESSAGE(!TwoWayComparable::less_or_equal_called,
160                     "operator <= was called during the comparison");
161     REQUIRE_MESSAGE(!TwoWayComparable::greater_or_equal_called,
162                     "operator >= was called during the comparison");
163     REQUIRE_MESSAGE(!(TwoWayComparable::equal_called),
164                     "operator == was called during the comparison");
165     REQUIRE_MESSAGE(!(TwoWayComparable::unequal_called),
166                     "operator == was called during the comparison");
167     TwoWayComparable::reset();
168 }
169 
170 // This function should be executed after comparing two objects, containing TwoWayComparables
171 // using operator == or !=
172 void check_equality_comparison() {
173     REQUIRE_MESSAGE(TwoWayComparable::equal_called,
174                     "operator == was not called during the comparison");
175     REQUIRE_MESSAGE(!(TwoWayComparable::unequal_called),
176                     "operator != was called during the comparison");
177     TwoWayComparable::reset();
178 }
179 
180 #if __TBB_TEST_CPP20_COMPARISONS
181 class ThreeWayComparable : public TwoWayComparable {
182 public:
183     ThreeWayComparable() : TwoWayComparable() { reset(); }
184 
185     ThreeWayComparable( std::size_t num ) : TwoWayComparable(num) { reset(); }
186 
187     static void reset() {
188         TwoWayComparable::reset();
189         three_way_called = false;
190     }
191 
192     static bool three_way_called;
193 
194     friend auto operator<=>( const ThreeWayComparable& lhs, const ThreeWayComparable& rhs ) {
195         three_way_called = true;
196         return lhs.n <=> rhs.n;
197     }
198 
199     friend bool operator==( const ThreeWayComparable&, const ThreeWayComparable& ) = default;
200 }; // class ThreeWayComparable
201 
202 bool ThreeWayComparable::three_way_called = false;
203 
204 // This function should be executed after comparing objects, containing ThreeWayComparables
205 // using one of the comparison operators (<=>, <, >, <=, >=)
206 void check_three_way_comparison() {
207     REQUIRE_MESSAGE(ThreeWayComparable::three_way_called, "operator <=> was not called during the comparison");
208     REQUIRE_MESSAGE(!ThreeWayComparable::less_called, "operator < was called during the comparison");
209     REQUIRE_MESSAGE(!ThreeWayComparable::greater_called, "operator > was called during the comparison");
210     REQUIRE_MESSAGE(!ThreeWayComparable::less_or_equal_called, "operator <= was called during the comparison");
211     REQUIRE_MESSAGE(!ThreeWayComparable::greater_or_equal_called, "operator >= was called during the comparison");
212     ThreeWayComparable::reset();
213 }
214 
215 // Required for testing synthesized_three_way_comparison
216 class ThreeWayComparableOnly {
217 public:
218     ThreeWayComparableOnly() : n(0) {}
219     ThreeWayComparableOnly( std::size_t num ) : n(num) {}
220 
221     friend auto operator<=>( const ThreeWayComparableOnly& lhs, const ThreeWayComparableOnly& rhs ) {
222         return lhs.n <=> rhs.n;
223     }
224     friend bool operator==( const ThreeWayComparableOnly& lhs, const ThreeWayComparableOnly& rhs ) {
225         return lhs.n == rhs.n;
226     }
227 private:
228     std::size_t n;
229 }; // class ThreeWayComparableOnly
230 
// Required for testing synthesized_three_way_comparison
// Wrapper over a number that declares only operator < and operator ==
class LessComparableOnly {
    // The wrapped number all the comparisons are based on
    std::size_t n;
public:
    LessComparableOnly() : LessComparableOnly(0) {}
    LessComparableOnly( std::size_t value ) : n(value) {}

    friend bool operator<( const LessComparableOnly& lhs, const LessComparableOnly& rhs ) {
        return rhs.n > lhs.n;
    }

    friend bool operator==( const LessComparableOnly& lhs, const LessComparableOnly& rhs ) {
        return rhs.n == lhs.n;
    }
}; // class LessComparableOnly
246 
247 #endif // __TBB_TEST_CPP20_COMPARISONS
248 } // namespace comparisons_testing
249 
250 namespace std {
251 
252 template <>
253 struct hash<comparisons_testing::TwoWayComparable> {
254     std::size_t operator()( const comparisons_testing::TwoWayComparable& val ) const {
255         return std::hash<std::size_t>{}(val.n);
256     }
257 };
258 
259 } // namespace std
260 
261 #endif // __TBB_test_common_test_comparisons_H
262