1 //===-- Utility class to test different flavors of rint ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_TEST_SRC_MATH_RINTTEST_H
10 #define LLVM_LIBC_TEST_SRC_MATH_RINTTEST_H
11 
12 #include "src/__support/FPUtil/FEnvImpl.h"
13 #include "src/__support/FPUtil/FPBits.h"
14 #include "utils/MPFRWrapper/MPFRUtils.h"
15 #include "utils/UnitTest/FPMatcher.h"
16 #include "utils/UnitTest/Test.h"
17 
18 #include <fenv.h>
19 #include <math.h>
20 #include <stdio.h>
21 
22 namespace mpfr = __llvm_libc::testing::mpfr;
23 
// Rounding modes (as <fenv.h> macros) that the tests in this file iterate
// over, in exactly this order.
static constexpr int ROUNDING_MODES[] = {
    FE_UPWARD,
    FE_DOWNWARD,
    FE_TOWARDZERO,
    FE_TONEAREST,
};
26 
27 template <typename T>
28 class RIntTestTemplate : public __llvm_libc::testing::Test {
29 public:
30   typedef T (*RIntFunc)(T);
31 
32 private:
33   using FPBits = __llvm_libc::fputil::FPBits<T>;
34   using UIntType = typename FPBits::UIntType;
35 
36   const T zero = T(FPBits::zero());
37   const T neg_zero = T(FPBits::neg_zero());
38   const T inf = T(FPBits::inf());
39   const T neg_inf = T(FPBits::neg_inf());
40   const T nan = T(FPBits::build_nan(1));
41 
to_mpfr_rounding_mode(int mode)42   static inline mpfr::RoundingMode to_mpfr_rounding_mode(int mode) {
43     switch (mode) {
44     case FE_UPWARD:
45       return mpfr::RoundingMode::Upward;
46     case FE_DOWNWARD:
47       return mpfr::RoundingMode::Downward;
48     case FE_TOWARDZERO:
49       return mpfr::RoundingMode::TowardZero;
50     case FE_TONEAREST:
51       return mpfr::RoundingMode::Nearest;
52     default:
53       __builtin_unreachable();
54     }
55   }
56 
57 public:
testSpecialNumbers(RIntFunc func)58   void testSpecialNumbers(RIntFunc func) {
59     for (int mode : ROUNDING_MODES) {
60       __llvm_libc::fputil::set_round(mode);
61       ASSERT_FP_EQ(inf, func(inf));
62       ASSERT_FP_EQ(neg_inf, func(neg_inf));
63       ASSERT_FP_EQ(nan, func(nan));
64       ASSERT_FP_EQ(zero, func(zero));
65       ASSERT_FP_EQ(neg_zero, func(neg_zero));
66     }
67   }
68 
testRoundNumbers(RIntFunc func)69   void testRoundNumbers(RIntFunc func) {
70     for (int mode : ROUNDING_MODES) {
71       __llvm_libc::fputil::set_round(mode);
72       mpfr::RoundingMode mpfr_mode = to_mpfr_rounding_mode(mode);
73       ASSERT_FP_EQ(func(T(1.0)), mpfr::round(T(1.0), mpfr_mode));
74       ASSERT_FP_EQ(func(T(-1.0)), mpfr::round(T(-1.0), mpfr_mode));
75       ASSERT_FP_EQ(func(T(10.0)), mpfr::round(T(10.0), mpfr_mode));
76       ASSERT_FP_EQ(func(T(-10.0)), mpfr::round(T(-10.0), mpfr_mode));
77       ASSERT_FP_EQ(func(T(1234.0)), mpfr::round(T(1234.0), mpfr_mode));
78       ASSERT_FP_EQ(func(T(-1234.0)), mpfr::round(T(-1234.0), mpfr_mode));
79     }
80   }
81 
testFractions(RIntFunc func)82   void testFractions(RIntFunc func) {
83     for (int mode : ROUNDING_MODES) {
84       __llvm_libc::fputil::set_round(mode);
85       mpfr::RoundingMode mpfr_mode = to_mpfr_rounding_mode(mode);
86       ASSERT_FP_EQ(func(T(0.5)), mpfr::round(T(0.5), mpfr_mode));
87       ASSERT_FP_EQ(func(T(-0.5)), mpfr::round(T(-0.5), mpfr_mode));
88       ASSERT_FP_EQ(func(T(0.115)), mpfr::round(T(0.115), mpfr_mode));
89       ASSERT_FP_EQ(func(T(-0.115)), mpfr::round(T(-0.115), mpfr_mode));
90       ASSERT_FP_EQ(func(T(0.715)), mpfr::round(T(0.715), mpfr_mode));
91       ASSERT_FP_EQ(func(T(-0.715)), mpfr::round(T(-0.715), mpfr_mode));
92     }
93   }
94 
testSubnormalRange(RIntFunc func)95   void testSubnormalRange(RIntFunc func) {
96     constexpr UIntType COUNT = 1000001;
97     constexpr UIntType STEP =
98         (FPBits::MAX_SUBNORMAL - FPBits::MIN_SUBNORMAL) / COUNT;
99     for (UIntType i = FPBits::MIN_SUBNORMAL; i <= FPBits::MAX_SUBNORMAL;
100          i += STEP) {
101       T x = T(FPBits(i));
102       for (int mode : ROUNDING_MODES) {
103         __llvm_libc::fputil::set_round(mode);
104         mpfr::RoundingMode mpfr_mode = to_mpfr_rounding_mode(mode);
105         ASSERT_FP_EQ(func(x), mpfr::round(x, mpfr_mode));
106       }
107     }
108   }
109 
testNormalRange(RIntFunc func)110   void testNormalRange(RIntFunc func) {
111     constexpr UIntType COUNT = 1000001;
112     constexpr UIntType STEP = (FPBits::MAX_NORMAL - FPBits::MIN_NORMAL) / COUNT;
113     for (UIntType i = FPBits::MIN_NORMAL; i <= FPBits::MAX_NORMAL; i += STEP) {
114       T x = T(FPBits(i));
115       // In normal range on x86 platforms, the long double implicit 1 bit can be
116       // zero making the numbers NaN. We will skip them.
117       if (isnan(x)) {
118         continue;
119       }
120 
121       for (int mode : ROUNDING_MODES) {
122         __llvm_libc::fputil::set_round(mode);
123         mpfr::RoundingMode mpfr_mode = to_mpfr_rounding_mode(mode);
124         ASSERT_FP_EQ(func(x), mpfr::round(x, mpfr_mode));
125       }
126     }
127   }
128 };
129 
// Declares the RIntTestTemplate test cases for one function.
//   F    - floating point type used as the template argument of
//          RIntTestTemplate.
//   func - function under test; its address is passed to each test method.
// The expansion introduces the alias LlvmLibcRIntTest and one TEST_F case per
// public test method of RIntTestTemplate.
#define LIST_RINT_TESTS(F, func)                                               \
  using LlvmLibcRIntTest = RIntTestTemplate<F>;                                \
  TEST_F(LlvmLibcRIntTest, specialNumbers) { testSpecialNumbers(&func); }      \
  TEST_F(LlvmLibcRIntTest, RoundNumbers) { testRoundNumbers(&func); }          \
  TEST_F(LlvmLibcRIntTest, Fractions) { testFractions(&func); }                \
  TEST_F(LlvmLibcRIntTest, SubnormalRange) { testSubnormalRange(&func); }      \
  TEST_F(LlvmLibcRIntTest, NormalRange) { testNormalRange(&func); }
137 
138 #endif // LLVM_LIBC_TEST_SRC_MATH_RINTTEST_H
139