1 //===-- Utility class to test different flavors of fma --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
10 #define LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
11 
12 #include "utils/FPUtil/FPBits.h"
13 #include "utils/FPUtil/TestHelpers.h"
14 #include "utils/MPFRWrapper/MPFRUtils.h"
15 #include "utils/UnitTest/Test.h"
16 #include "utils/testutils/RandUtils.h"
17 
18 namespace mpfr = __llvm_libc::testing::mpfr;
19 
20 template <typename T>
21 class FmaTestTemplate : public __llvm_libc::testing::Test {
22 private:
23   using Func = T (*)(T, T, T);
24   using FPBits = __llvm_libc::fputil::FPBits<T>;
25   using UIntType = typename FPBits::UIntType;
26   const T nan = __llvm_libc::fputil::FPBits<T>::buildNaN(1);
27   const T inf = __llvm_libc::fputil::FPBits<T>::inf();
28   const T negInf = __llvm_libc::fputil::FPBits<T>::negInf();
29   const T zero = __llvm_libc::fputil::FPBits<T>::zero();
30   const T negZero = __llvm_libc::fputil::FPBits<T>::negZero();
31 
32   UIntType getRandomBitPattern() {
33     UIntType bits{0};
34     for (UIntType i = 0; i < sizeof(UIntType) / 2; ++i) {
35       bits =
36           (bits << 2) + static_cast<uint16_t>(__llvm_libc::testutils::rand());
37     }
38     return bits;
39   }
40 
41 public:
42   void testSpecialNumbers(Func func) {
43     EXPECT_FP_EQ(func(zero, zero, zero), zero);
44     EXPECT_FP_EQ(func(zero, negZero, negZero), negZero);
45     EXPECT_FP_EQ(func(inf, inf, zero), inf);
46     EXPECT_FP_EQ(func(negInf, inf, negInf), negInf);
47     EXPECT_FP_EQ(func(inf, zero, zero), nan);
48     EXPECT_FP_EQ(func(inf, negInf, inf), nan);
49     EXPECT_FP_EQ(func(nan, zero, inf), nan);
50     EXPECT_FP_EQ(func(inf, negInf, nan), nan);
51 
52     // Test underflow rounding up.
53     EXPECT_FP_EQ(func(T(0.5), FPBits(FPBits::minSubnormal),
54                       FPBits(FPBits::minSubnormal)),
55                  FPBits(UIntType(2)));
56     // Test underflow rounding down.
57     FPBits v(FPBits::minNormal + UIntType(1));
58     EXPECT_FP_EQ(
59         func(T(1) / T(FPBits::minNormal << 1), v, FPBits(FPBits::minNormal)),
60         v);
61     // Test overflow.
62     FPBits z(FPBits::maxNormal);
63     EXPECT_FP_EQ(func(T(1.75), z, -z), T(0.75) * z);
64   }
65 
66   void testSubnormalRange(Func func) {
67     constexpr UIntType count = 1000001;
68     constexpr UIntType step =
69         (FPBits::maxSubnormal - FPBits::minSubnormal) / count;
70     for (UIntType v = FPBits::minSubnormal, w = FPBits::maxSubnormal;
71          v <= FPBits::maxSubnormal && w >= FPBits::minSubnormal;
72          v += step, w -= step) {
73       T x = FPBits(getRandomBitPattern()), y = FPBits(v), z = FPBits(w);
74       T result = func(x, y, z);
75       mpfr::TernaryInput<T> input{x, y, z};
76       ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5);
77     }
78   }
79 
80   void testNormalRange(Func func) {
81     constexpr UIntType count = 1000001;
82     constexpr UIntType step = (FPBits::maxNormal - FPBits::minNormal) / count;
83     for (UIntType v = FPBits::minNormal, w = FPBits::maxNormal;
84          v <= FPBits::maxNormal && w >= FPBits::minNormal;
85          v += step, w -= step) {
86       T x = FPBits(v), y = FPBits(w), z = FPBits(getRandomBitPattern());
87       T result = func(x, y, z);
88       mpfr::TernaryInput<T> input{x, y, z};
89       ASSERT_MPFR_MATCH(mpfr::Operation::Fma, input, result, 0.5);
90     }
91   }
92 };
93 
94 #endif // LLVM_LIBC_TEST_SRC_MATH_FMATEST_H
95