1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 11 12 #include "src/__support/CPP/TypeTraits.h" 13 #include "utils/UnitTest/Test.h" 14 15 #include <stdint.h> 16 17 namespace __llvm_libc { 18 namespace testing { 19 namespace mpfr { 20 21 enum class Operation : int { 22 // Operations with take a single floating point number as input 23 // and produce a single floating point number as output. The input 24 // and output floating point numbers are of the same kind. 25 BeginUnaryOperationsSingleOutput, 26 Abs, 27 Ceil, 28 Cos, 29 Exp, 30 Exp2, 31 Expm1, 32 Floor, 33 Log, 34 Log2, 35 Mod2PI, 36 ModPIOver2, 37 ModPIOver4, 38 Round, 39 Sin, 40 Sqrt, 41 Tan, 42 Trunc, 43 EndUnaryOperationsSingleOutput, 44 45 // Operations which take a single floating point nubmer as input 46 // but produce two outputs. The first ouput is a floating point 47 // number of the same type as the input. The second output is of type 48 // 'int'. 49 BeginUnaryOperationsTwoOutputs, 50 Frexp, // Floating point output, the first output, is the fractional part. 51 EndUnaryOperationsTwoOutputs, 52 53 // Operations wich take two floating point nubmers of the same type as 54 // input and produce a single floating point number of the same type as 55 // output. 56 BeginBinaryOperationsSingleOutput, 57 Hypot, 58 EndBinaryOperationsSingleOutput, 59 60 // Operations which take two floating point numbers of the same type as 61 // input and produce two outputs. The first output is a floating nubmer of 62 // the same type as the inputs. The second output is af type 'int'. 63 BeginBinaryOperationsTwoOutputs, 64 RemQuo, // The first output, the floating point output, is the remainder. 65 EndBinaryOperationsTwoOutputs, 66 67 // Operations which take three floating point nubmers of the same type as 68 // input and produce a single floating point number of the same type as 69 // output. 70 BeginTernaryOperationsSingleOuput, 71 Fma, 72 EndTernaryOperationsSingleOutput, 73 }; 74 75 enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest }; 76 77 int get_fe_rounding(RoundingMode mode); 78 79 struct ForceRoundingMode { 80 ForceRoundingMode(RoundingMode); 81 ~ForceRoundingMode(); 82 83 int old_rounding_mode; 84 int rounding_mode; 85 }; 86 87 template <typename T> struct BinaryInput { 88 static_assert( 89 __llvm_libc::cpp::IsFloatingPointType<T>::Value, 90 "Template parameter of BinaryInput must be a floating point type."); 91 92 using Type = T; 93 T x, y; 94 }; 95 96 template <typename T> struct TernaryInput { 97 static_assert( 98 __llvm_libc::cpp::IsFloatingPointType<T>::Value, 99 "Template parameter of TernaryInput must be a floating point type."); 100 101 using Type = T; 102 T x, y, z; 103 }; 104 105 template <typename T> struct BinaryOutput { 106 T f; 107 int i; 108 }; 109 110 namespace internal { 111 112 template <typename T1, typename T2> 113 struct AreMatchingBinaryInputAndBinaryOutput { 114 static constexpr bool VALUE = false; 115 }; 116 117 template <typename T> 118 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> { 119 static constexpr bool VALUE = cpp::IsFloatingPointType<T>::Value; 120 }; 121 122 template <typename T> 123 bool compare_unary_operation_single_output(Operation op, T input, T libc_output, 124 double ulp_tolerance, 125 RoundingMode rounding); 126 template <typename T> 127 bool compare_unary_operation_two_outputs(Operation op, T input, 128 const BinaryOutput<T> &libc_output, 129 double ulp_tolerance, 130 RoundingMode rounding); 131 template <typename T> 132 bool compare_binary_operation_two_outputs(Operation op, 133 const BinaryInput<T> &input, 134 const BinaryOutput<T> &libc_output, 135 double ulp_tolerance, 136 RoundingMode rounding); 137 138 template <typename T> 139 bool compare_binary_operation_one_output(Operation op, 140 const BinaryInput<T> &input, 141 T libc_output, double ulp_tolerance, 142 RoundingMode rounding); 143 144 template <typename T> 145 bool compare_ternary_operation_one_output(Operation op, 146 const TernaryInput<T> &input, 147 T libc_output, double ulp_tolerance, 148 RoundingMode rounding); 149 150 template <typename T> 151 void explain_unary_operation_single_output_error(Operation op, T input, 152 T match_value, 153 double ulp_tolerance, 154 RoundingMode rounding, 155 testutils::StreamWrapper &OS); 156 template <typename T> 157 void explain_unary_operation_two_outputs_error( 158 Operation op, T input, const BinaryOutput<T> &match_value, 159 double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); 160 template <typename T> 161 void explain_binary_operation_two_outputs_error( 162 Operation op, const BinaryInput<T> &input, 163 const BinaryOutput<T> &match_value, double ulp_tolerance, 164 RoundingMode rounding, testutils::StreamWrapper &OS); 165 166 template <typename T> 167 void explain_binary_operation_one_output_error( 168 Operation op, const BinaryInput<T> &input, T match_value, 169 double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); 170 171 template <typename T> 172 void explain_ternary_operation_one_output_error( 173 Operation op, const TernaryInput<T> &input, T match_value, 174 double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); 175 176 template <Operation op, typename InputType, typename OutputType> 177 class MPFRMatcher : public testing::Matcher<OutputType> { 178 InputType input; 179 OutputType match_value; 180 double ulp_tolerance; 181 RoundingMode rounding; 182 183 public: 184 MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding) 185 : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {} 186 187 bool match(OutputType libcResult) { 188 match_value = libcResult; 189 return match(input, match_value); 190 } 191 192 // This method is marked with NOLINT because it the name `explainError` 193 // does not confirm to the coding style. 194 void explainError(testutils::StreamWrapper &OS) override { // NOLINT 195 explain_error(input, match_value, OS); 196 } 197 198 private: 199 template <typename T> bool match(T in, T out) { 200 return compare_unary_operation_single_output(op, in, out, ulp_tolerance, 201 rounding); 202 } 203 204 template <typename T> bool match(T in, const BinaryOutput<T> &out) { 205 return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance, 206 rounding); 207 } 208 209 template <typename T> bool match(const BinaryInput<T> &in, T out) { 210 return compare_binary_operation_one_output(op, in, out, ulp_tolerance, 211 rounding); 212 } 213 214 template <typename T> 215 bool match(BinaryInput<T> in, const BinaryOutput<T> &out) { 216 return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance, 217 rounding); 218 } 219 220 template <typename T> bool match(const TernaryInput<T> &in, T out) { 221 return compare_ternary_operation_one_output(op, in, out, ulp_tolerance, 222 rounding); 223 } 224 225 template <typename T> 226 void explain_error(T in, T out, testutils::StreamWrapper &OS) { 227 explain_unary_operation_single_output_error(op, in, out, ulp_tolerance, 228 rounding, OS); 229 } 230 231 template <typename T> 232 void explain_error(T in, const BinaryOutput<T> &out, 233 testutils::StreamWrapper &OS) { 234 explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance, 235 rounding, OS); 236 } 237 238 template <typename T> 239 void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out, 240 testutils::StreamWrapper &OS) { 241 explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance, 242 rounding, OS); 243 } 244 245 template <typename T> 246 void explain_error(const BinaryInput<T> &in, T out, 247 testutils::StreamWrapper &OS) { 248 explain_binary_operation_one_output_error(op, in, out, ulp_tolerance, 249 rounding, OS); 250 } 251 252 template <typename T> 253 void explain_error(const TernaryInput<T> &in, T out, 254 testutils::StreamWrapper &OS) { 255 explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance, 256 rounding, OS); 257 } 258 }; 259 260 } // namespace internal 261 262 // Return true if the input and ouput types for the operation op are valid 263 // types. 264 template <Operation op, typename InputType, typename OutputType> 265 constexpr bool is_valid_operation() { 266 return (Operation::BeginUnaryOperationsSingleOutput < op && 267 op < Operation::EndUnaryOperationsSingleOutput && 268 cpp::IsSame<InputType, OutputType>::Value && 269 cpp::IsFloatingPointType<InputType>::Value) || 270 (Operation::BeginUnaryOperationsTwoOutputs < op && 271 op < Operation::EndUnaryOperationsTwoOutputs && 272 cpp::IsFloatingPointType<InputType>::Value && 273 cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) || 274 (Operation::BeginBinaryOperationsSingleOutput < op && 275 op < Operation::EndBinaryOperationsSingleOutput && 276 cpp::IsFloatingPointType<OutputType>::Value && 277 cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) || 278 (Operation::BeginBinaryOperationsTwoOutputs < op && 279 op < Operation::EndBinaryOperationsTwoOutputs && 280 internal::AreMatchingBinaryInputAndBinaryOutput<InputType, 281 OutputType>::VALUE) || 282 (Operation::BeginTernaryOperationsSingleOuput < op && 283 op < Operation::EndTernaryOperationsSingleOutput && 284 cpp::IsFloatingPointType<OutputType>::Value && 285 cpp::IsSame<InputType, TernaryInput<OutputType>>::Value); 286 } 287 288 template <Operation op, typename InputType, typename OutputType> 289 __attribute__((no_sanitize("address"))) 290 cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(), 291 internal::MPFRMatcher<op, InputType, OutputType>> 292 get_mpfr_matcher(InputType input, OutputType output_unused, 293 double ulp_tolerance, RoundingMode rounding) { 294 return internal::MPFRMatcher<op, InputType, OutputType>(input, ulp_tolerance, 295 rounding); 296 } 297 298 template <typename T> T round(T x, RoundingMode mode); 299 300 template <typename T> bool round_to_long(T x, long &result); 301 template <typename T> bool round_to_long(T x, RoundingMode mode, long &result); 302 303 } // namespace mpfr 304 } // namespace testing 305 } // namespace __llvm_libc 306 307 // GET_MPFR_DUMMY_ARG is going to be added to the end of GET_MPFR_MACRO as a 308 // simple way to avoid the compiler warning `gnu-zero-variadic-macro-arguments`. 309 #define GET_MPFR_DUMMY_ARG(...) 0 310 311 #define GET_MPFR_MACRO(__1, __2, __3, __4, __5, __NAME, ...) __NAME 312 313 #define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ 314 EXPECT_THAT(match_value, \ 315 __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \ 316 input, match_value, ulp_tolerance, \ 317 __llvm_libc::testing::mpfr::RoundingMode::Nearest)) 318 319 #define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 320 rounding) \ 321 EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \ 322 input, match_value, ulp_tolerance, rounding)) 323 324 #define EXPECT_MPFR_MATCH(...) \ 325 GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING, \ 326 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \ 327 (__VA_ARGS__) 328 329 #define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \ 330 { \ 331 namespace mpfr = __llvm_libc::testing::mpfr; \ 332 mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest); \ 333 EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 334 mpfr::RoundingMode::Nearest); \ 335 mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward); \ 336 EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 337 mpfr::RoundingMode::Upward); \ 338 mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward); \ 339 EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 340 mpfr::RoundingMode::Downward); \ 341 mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero); \ 342 EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 343 mpfr::RoundingMode::TowardZero); \ 344 } 345 346 #define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ 347 ASSERT_THAT(match_value, \ 348 __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \ 349 input, match_value, ulp_tolerance, \ 350 __llvm_libc::testing::mpfr::RoundingMode::Nearest)) 351 352 #define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 353 rounding) \ 354 ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \ 355 input, match_value, ulp_tolerance, rounding)) 356 357 #define ASSERT_MPFR_MATCH(...) \ 358 GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING, \ 359 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \ 360 (__VA_ARGS__) 361 362 #define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \ 363 { \ 364 namespace mpfr = __llvm_libc::testing::mpfr; \ 365 mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest); \ 366 ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 367 mpfr::RoundingMode::Nearest); \ 368 mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward); \ 369 ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 370 mpfr::RoundingMode::Upward); \ 371 mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward); \ 372 ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 373 mpfr::RoundingMode::Downward); \ 374 mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero); \ 375 ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 376 mpfr::RoundingMode::TowardZero); \ 377 } 378 379 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 380