1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 11 12 #include "src/__support/CPP/TypeTraits.h" 13 #include "utils/UnitTest/Test.h" 14 15 #include <stdint.h> 16 17 namespace __llvm_libc { 18 namespace testing { 19 namespace mpfr { 20 21 enum class Operation : int { 22 // Operations with take a single floating point number as input 23 // and produce a single floating point number as output. The input 24 // and output floating point numbers are of the same kind. 25 BeginUnaryOperationsSingleOutput, 26 Abs, 27 Ceil, 28 Cos, 29 Exp, 30 Exp2, 31 Expm1, 32 Floor, 33 Log, 34 Log2, 35 Log10, 36 Log1p, 37 Mod2PI, 38 ModPIOver2, 39 ModPIOver4, 40 Round, 41 Sin, 42 Sqrt, 43 Tan, 44 Trunc, 45 EndUnaryOperationsSingleOutput, 46 47 // Operations which take a single floating point nubmer as input 48 // but produce two outputs. The first ouput is a floating point 49 // number of the same type as the input. The second output is of type 50 // 'int'. 51 BeginUnaryOperationsTwoOutputs, 52 Frexp, // Floating point output, the first output, is the fractional part. 53 EndUnaryOperationsTwoOutputs, 54 55 // Operations wich take two floating point nubmers of the same type as 56 // input and produce a single floating point number of the same type as 57 // output. 58 BeginBinaryOperationsSingleOutput, 59 Fmod, 60 Hypot, 61 EndBinaryOperationsSingleOutput, 62 63 // Operations which take two floating point numbers of the same type as 64 // input and produce two outputs. The first output is a floating nubmer of 65 // the same type as the inputs. The second output is af type 'int'. 66 BeginBinaryOperationsTwoOutputs, 67 RemQuo, // The first output, the floating point output, is the remainder. 68 EndBinaryOperationsTwoOutputs, 69 70 // Operations which take three floating point nubmers of the same type as 71 // input and produce a single floating point number of the same type as 72 // output. 73 BeginTernaryOperationsSingleOuput, 74 Fma, 75 EndTernaryOperationsSingleOutput, 76 }; 77 78 enum class RoundingMode : uint8_t { Upward, Downward, TowardZero, Nearest }; 79 80 int get_fe_rounding(RoundingMode mode); 81 82 struct ForceRoundingMode { 83 ForceRoundingMode(RoundingMode); 84 ~ForceRoundingMode(); 85 86 int old_rounding_mode; 87 int rounding_mode; 88 }; 89 90 template <typename T> struct BinaryInput { 91 static_assert( 92 __llvm_libc::cpp::IsFloatingPointType<T>::Value, 93 "Template parameter of BinaryInput must be a floating point type."); 94 95 using Type = T; 96 T x, y; 97 }; 98 99 template <typename T> struct TernaryInput { 100 static_assert( 101 __llvm_libc::cpp::IsFloatingPointType<T>::Value, 102 "Template parameter of TernaryInput must be a floating point type."); 103 104 using Type = T; 105 T x, y, z; 106 }; 107 108 template <typename T> struct BinaryOutput { 109 T f; 110 int i; 111 }; 112 113 namespace internal { 114 115 template <typename T1, typename T2> 116 struct AreMatchingBinaryInputAndBinaryOutput { 117 static constexpr bool VALUE = false; 118 }; 119 120 template <typename T> 121 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> { 122 static constexpr bool VALUE = cpp::IsFloatingPointType<T>::Value; 123 }; 124 125 template <typename T> 126 bool compare_unary_operation_single_output(Operation op, T input, T libc_output, 127 double ulp_tolerance, 128 RoundingMode rounding); 129 template <typename T> 130 bool compare_unary_operation_two_outputs(Operation op, T input, 131 const BinaryOutput<T> &libc_output, 132 double ulp_tolerance, 133 RoundingMode rounding); 134 template <typename T> 135 bool compare_binary_operation_two_outputs(Operation op, 136 const BinaryInput<T> &input, 137 const BinaryOutput<T> &libc_output, 138 double ulp_tolerance, 139 RoundingMode rounding); 140 141 template <typename T> 142 bool compare_binary_operation_one_output(Operation op, 143 const BinaryInput<T> &input, 144 T libc_output, double ulp_tolerance, 145 RoundingMode rounding); 146 147 template <typename T> 148 bool compare_ternary_operation_one_output(Operation op, 149 const TernaryInput<T> &input, 150 T libc_output, double ulp_tolerance, 151 RoundingMode rounding); 152 153 template <typename T> 154 void explain_unary_operation_single_output_error(Operation op, T input, 155 T match_value, 156 double ulp_tolerance, 157 RoundingMode rounding, 158 testutils::StreamWrapper &OS); 159 template <typename T> 160 void explain_unary_operation_two_outputs_error( 161 Operation op, T input, const BinaryOutput<T> &match_value, 162 double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); 163 template <typename T> 164 void explain_binary_operation_two_outputs_error( 165 Operation op, const BinaryInput<T> &input, 166 const BinaryOutput<T> &match_value, double ulp_tolerance, 167 RoundingMode rounding, testutils::StreamWrapper &OS); 168 169 template <typename T> 170 void explain_binary_operation_one_output_error( 171 Operation op, const BinaryInput<T> &input, T match_value, 172 double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); 173 174 template <typename T> 175 void explain_ternary_operation_one_output_error( 176 Operation op, const TernaryInput<T> &input, T match_value, 177 double ulp_tolerance, RoundingMode rounding, testutils::StreamWrapper &OS); 178 179 template <Operation op, typename InputType, typename OutputType> 180 class MPFRMatcher : public testing::Matcher<OutputType> { 181 InputType input; 182 OutputType match_value; 183 double ulp_tolerance; 184 RoundingMode rounding; 185 186 public: 187 MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding) 188 : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {} 189 190 bool match(OutputType libcResult) { 191 match_value = libcResult; 192 return match(input, match_value); 193 } 194 195 // This method is marked with NOLINT because it the name `explainError` 196 // does not confirm to the coding style. 197 void explainError(testutils::StreamWrapper &OS) override { // NOLINT 198 explain_error(input, match_value, OS); 199 } 200 201 private: 202 template <typename T> bool match(T in, T out) { 203 return compare_unary_operation_single_output(op, in, out, ulp_tolerance, 204 rounding); 205 } 206 207 template <typename T> bool match(T in, const BinaryOutput<T> &out) { 208 return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance, 209 rounding); 210 } 211 212 template <typename T> bool match(const BinaryInput<T> &in, T out) { 213 return compare_binary_operation_one_output(op, in, out, ulp_tolerance, 214 rounding); 215 } 216 217 template <typename T> 218 bool match(BinaryInput<T> in, const BinaryOutput<T> &out) { 219 return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance, 220 rounding); 221 } 222 223 template <typename T> bool match(const TernaryInput<T> &in, T out) { 224 return compare_ternary_operation_one_output(op, in, out, ulp_tolerance, 225 rounding); 226 } 227 228 template <typename T> 229 void explain_error(T in, T out, testutils::StreamWrapper &OS) { 230 explain_unary_operation_single_output_error(op, in, out, ulp_tolerance, 231 rounding, OS); 232 } 233 234 template <typename T> 235 void explain_error(T in, const BinaryOutput<T> &out, 236 testutils::StreamWrapper &OS) { 237 explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance, 238 rounding, OS); 239 } 240 241 template <typename T> 242 void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out, 243 testutils::StreamWrapper &OS) { 244 explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance, 245 rounding, OS); 246 } 247 248 template <typename T> 249 void explain_error(const BinaryInput<T> &in, T out, 250 testutils::StreamWrapper &OS) { 251 explain_binary_operation_one_output_error(op, in, out, ulp_tolerance, 252 rounding, OS); 253 } 254 255 template <typename T> 256 void explain_error(const TernaryInput<T> &in, T out, 257 testutils::StreamWrapper &OS) { 258 explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance, 259 rounding, OS); 260 } 261 }; 262 263 } // namespace internal 264 265 // Return true if the input and ouput types for the operation op are valid 266 // types. 267 template <Operation op, typename InputType, typename OutputType> 268 constexpr bool is_valid_operation() { 269 return (Operation::BeginUnaryOperationsSingleOutput < op && 270 op < Operation::EndUnaryOperationsSingleOutput && 271 cpp::IsSame<InputType, OutputType>::Value && 272 cpp::IsFloatingPointType<InputType>::Value) || 273 (Operation::BeginUnaryOperationsTwoOutputs < op && 274 op < Operation::EndUnaryOperationsTwoOutputs && 275 cpp::IsFloatingPointType<InputType>::Value && 276 cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) || 277 (Operation::BeginBinaryOperationsSingleOutput < op && 278 op < Operation::EndBinaryOperationsSingleOutput && 279 cpp::IsFloatingPointType<OutputType>::Value && 280 cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) || 281 (Operation::BeginBinaryOperationsTwoOutputs < op && 282 op < Operation::EndBinaryOperationsTwoOutputs && 283 internal::AreMatchingBinaryInputAndBinaryOutput<InputType, 284 OutputType>::VALUE) || 285 (Operation::BeginTernaryOperationsSingleOuput < op && 286 op < Operation::EndTernaryOperationsSingleOutput && 287 cpp::IsFloatingPointType<OutputType>::Value && 288 cpp::IsSame<InputType, TernaryInput<OutputType>>::Value); 289 } 290 291 template <Operation op, typename InputType, typename OutputType> 292 __attribute__((no_sanitize("address"))) 293 cpp::EnableIfType<is_valid_operation<op, InputType, OutputType>(), 294 internal::MPFRMatcher<op, InputType, OutputType>> 295 get_mpfr_matcher(InputType input, OutputType output_unused, 296 double ulp_tolerance, RoundingMode rounding) { 297 return internal::MPFRMatcher<op, InputType, OutputType>(input, ulp_tolerance, 298 rounding); 299 } 300 301 template <typename T> T round(T x, RoundingMode mode); 302 303 template <typename T> bool round_to_long(T x, long &result); 304 template <typename T> bool round_to_long(T x, RoundingMode mode, long &result); 305 306 } // namespace mpfr 307 } // namespace testing 308 } // namespace __llvm_libc 309 310 // GET_MPFR_DUMMY_ARG is going to be added to the end of GET_MPFR_MACRO as a 311 // simple way to avoid the compiler warning `gnu-zero-variadic-macro-arguments`. 312 #define GET_MPFR_DUMMY_ARG(...) 0 313 314 #define GET_MPFR_MACRO(__1, __2, __3, __4, __5, __NAME, ...) __NAME 315 316 #define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ 317 EXPECT_THAT(match_value, \ 318 __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \ 319 input, match_value, ulp_tolerance, \ 320 __llvm_libc::testing::mpfr::RoundingMode::Nearest)) 321 322 #define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 323 rounding) \ 324 EXPECT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \ 325 input, match_value, ulp_tolerance, rounding)) 326 327 #define EXPECT_MPFR_MATCH(...) \ 328 GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING, \ 329 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \ 330 (__VA_ARGS__) 331 332 #define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \ 333 { \ 334 namespace mpfr = __llvm_libc::testing::mpfr; \ 335 mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest); \ 336 EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 337 mpfr::RoundingMode::Nearest); \ 338 mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward); \ 339 EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 340 mpfr::RoundingMode::Upward); \ 341 mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward); \ 342 EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 343 mpfr::RoundingMode::Downward); \ 344 mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero); \ 345 EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 346 mpfr::RoundingMode::TowardZero); \ 347 } 348 349 #define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance) \ 350 ASSERT_THAT(match_value, \ 351 __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \ 352 input, match_value, ulp_tolerance, \ 353 __llvm_libc::testing::mpfr::RoundingMode::Nearest)) 354 355 #define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance, \ 356 rounding) \ 357 ASSERT_THAT(match_value, __llvm_libc::testing::mpfr::get_mpfr_matcher<op>( \ 358 input, match_value, ulp_tolerance, rounding)) 359 360 #define ASSERT_MPFR_MATCH(...) \ 361 GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING, \ 362 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG) \ 363 (__VA_ARGS__) 364 365 #define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance) \ 366 { \ 367 namespace mpfr = __llvm_libc::testing::mpfr; \ 368 mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest); \ 369 ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 370 mpfr::RoundingMode::Nearest); \ 371 mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward); \ 372 ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 373 mpfr::RoundingMode::Upward); \ 374 mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward); \ 375 ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 376 mpfr::RoundingMode::Downward); \ 377 mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero); \ 378 ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance, \ 379 mpfr::RoundingMode::TowardZero); \ 380 } 381 382 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 383