1 //===-- Utility class to test sqrt[f|l] -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "utils/MPFRWrapper/MPFRUtils.h"
10 #include "utils/UnitTest/FPMatcher.h"
11 #include "utils/UnitTest/Test.h"
12 
13 #include <math.h>
14 
15 namespace mpfr = __llvm_libc::testing::mpfr;
16 
17 template <typename T> class SqrtTest : public __llvm_libc::testing::Test {
18 
19   DECLARE_SPECIAL_CONSTANTS(T)
20 
21   static constexpr UIntType HIDDEN_BIT =
22       UIntType(1) << __llvm_libc::fputil::MantissaWidth<T>::VALUE;
23 
24 public:
25   typedef T (*SqrtFunc)(T);
26 
27   void test_special_numbers(SqrtFunc func) {
28     ASSERT_FP_EQ(aNaN, func(aNaN));
29     ASSERT_FP_EQ(inf, func(inf));
30     ASSERT_FP_EQ(aNaN, func(neg_inf));
31     ASSERT_FP_EQ(0.0, func(0.0));
32     ASSERT_FP_EQ(-0.0, func(-0.0));
33     ASSERT_FP_EQ(aNaN, func(T(-1.0)));
34     ASSERT_FP_EQ(T(1.0), func(T(1.0)));
35     ASSERT_FP_EQ(T(2.0), func(T(4.0)));
36     ASSERT_FP_EQ(T(3.0), func(T(9.0)));
37   }
38 
39   void test_denormal_values(SqrtFunc func) {
40     for (UIntType mant = 1; mant < HIDDEN_BIT; mant <<= 1) {
41       FPBits denormal(T(0.0));
42       denormal.set_mantissa(mant);
43 
44       test_all_rounding_modes(func, T(denormal));
45     }
46 
47     constexpr UIntType COUNT = 1'000'001;
48     constexpr UIntType STEP = HIDDEN_BIT / COUNT;
49     for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
50       T x = *reinterpret_cast<T *>(&v);
51       test_all_rounding_modes(func, x);
52     }
53   }
54 
55   void test_normal_range(SqrtFunc func) {
56     constexpr UIntType COUNT = 10'000'001;
57     constexpr UIntType STEP = UIntType(-1) / COUNT;
58     for (UIntType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
59       T x = *reinterpret_cast<T *>(&v);
60       if (isnan(x) || (x < 0)) {
61         continue;
62       }
63       test_all_rounding_modes(func, x);
64     }
65   }
66 
67   void test_all_rounding_modes(SqrtFunc func, T x) {
68     mpfr::ForceRoundingMode r1(mpfr::RoundingMode::Nearest);
69     EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
70                       mpfr::RoundingMode::Nearest);
71 
72     mpfr::ForceRoundingMode r2(mpfr::RoundingMode::Upward);
73     EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
74                       mpfr::RoundingMode::Upward);
75 
76     mpfr::ForceRoundingMode r3(mpfr::RoundingMode::Downward);
77     EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
78                       mpfr::RoundingMode::Downward);
79 
80     mpfr::ForceRoundingMode r4(mpfr::RoundingMode::TowardZero);
81     EXPECT_MPFR_MATCH(mpfr::Operation::Sqrt, x, func(x), 0.5,
82                       mpfr::RoundingMode::TowardZero);
83   }
84 };
85 
// Defines the three sqrt tests (SpecialNumbers, DenormalValues, NormalRange)
// for the floating point type `T`, each one forwarding the address of `func`
// to the matching SqrtTest<T> method.
//
// The alias name and the test names are fixed, so expand this macro at most
// once per translation unit.
#define LIST_SQRT_TESTS(T, func)                                               \
  using LlvmLibcSqrtTest = SqrtTest<T>;                                        \
  TEST_F(LlvmLibcSqrtTest, SpecialNumbers) {                                   \
    test_special_numbers(&func);                                               \
  }                                                                            \
  TEST_F(LlvmLibcSqrtTest, DenormalValues) {                                   \
    test_denormal_values(&func);                                               \
  }                                                                            \
  TEST_F(LlvmLibcSqrtTest, NormalRange) {                                      \
    test_normal_range(&func);                                                  \
  }
91