1 //===-- Utils which wrap MPFR ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPFRUtils.h"
10 
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/StringRef.h"
13 
#include <mpfr.h>
#include <stdint.h>
#include <string.h>
#include <string>
17 
18 namespace __llvm_libc {
19 namespace testing {
20 namespace mpfr {
21 
// Describes the binary layout of a floating point type. Only the
// specializations below define members, so instantiating getBits or
// getExponent with any other type fails to compile.
template <typename T> struct FloatProperties {};

template <> struct FloatProperties<float> {
  typedef uint32_t BitsType;
  static_assert(sizeof(BitsType) == sizeof(float),
                "Unexpected size of 'float' type.");

  // Number of explicitly stored mantissa (fraction) bits.
  static constexpr uint32_t mantissaWidth = 23;
  // NOTE: despite its name, this mask has every bit set EXCEPT the sign bit;
  // ANDing the raw bits with it clears the sign, giving the bits of abs(x).
  static constexpr BitsType signMask = 0x7FFFFFFFU;
  // Exponent bias subtracted from the stored exponent field.
  static constexpr uint32_t exponentOffset = 127;
};

template <> struct FloatProperties<double> {
  typedef uint64_t BitsType;
  static_assert(sizeof(BitsType) == sizeof(double),
                "Unexpected size of 'double' type.");

  // Number of explicitly stored mantissa (fraction) bits.
  static constexpr uint32_t mantissaWidth = 52;
  // NOTE: despite its name, this mask has every bit set EXCEPT the sign bit;
  // ANDing the raw bits with it clears the sign, giving the bits of abs(x).
  static constexpr BitsType signMask = 0x7FFFFFFFFFFFFFFFULL;
  // Exponent bias subtracted from the stored exponent field.
  static constexpr uint32_t exponentOffset = 1023;
};

// Returns the raw bit pattern of |x|.
// memcpy is used rather than dereferencing a reinterpret_cast pointer, which
// would access a float object through an integer lvalue and violate the
// strict aliasing rules (undefined behavior).
template <typename T> typename FloatProperties<T>::BitsType getBits(T x) {
  using BitsType = typename FloatProperties<T>::BitsType;
  BitsType bits;
  memcpy(&bits, &x, sizeof(bits));
  return bits;
}

// Returns the zero adjusted (unbiased) exponent value of abs(x): the stored
// exponent field minus FloatProperties<T>::exponentOffset.
// For zero and subnormal inputs the exponent field is 0, so this returns
// -exponentOffset (e.g. -127 for float).
template <typename T> int getExponent(T x) {
  using Properties = FloatProperties<T>;
  using BitsType = typename Properties::BitsType;
  BitsType bits = getBits(x);
  bits &= Properties::signMask; // Zero the sign bit.
  // Only the exponent field survives shifting out the mantissa; it is at most
  // 11 bits wide, so the conversion to int is lossless.
  const int e = static_cast<int>(bits >> Properties::mantissaWidth);
  // Subtract in signed arithmetic: exponentOffset is unsigned, and mixing it
  // with |e| directly would wrap around for exponents below the bias.
  return e - static_cast<int>(Properties::exponentOffset);
}
59 
60 class MPFRNumber {
61   // A precision value which allows sufficiently large additional
62   // precision even compared to double precision floating point values.
63   static constexpr unsigned int mpfrPrecision = 96;
64 
65   mpfr_t value;
66 
67 public:
68   MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
69 
70   // We use explicit EnableIf specializations to disallow implicit
71   // conversions. Implicit conversions can potentially lead to loss of
72   // precision.
73   template <typename XType,
74             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
75   explicit MPFRNumber(XType x) {
76     mpfr_init2(value, mpfrPrecision);
77     mpfr_set_flt(value, x, MPFR_RNDN);
78   }
79 
80   template <typename XType,
81             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
82   explicit MPFRNumber(XType x) {
83     mpfr_init2(value, mpfrPrecision);
84     mpfr_set_d(value, x, MPFR_RNDN);
85   }
86 
87   template <typename XType,
88             cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
89   explicit MPFRNumber(XType x) {
90     mpfr_init2(value, mpfrPrecision);
91     mpfr_set_sj(value, x, MPFR_RNDN);
92   }
93 
94   template <typename XType> MPFRNumber(XType x, const Tolerance &t) {
95     mpfr_init2(value, mpfrPrecision);
96     mpfr_set_zero(value, 1); // Set to positive zero.
97     MPFRNumber xExponent(getExponent(x));
98     // E = 2^E
99     mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN);
100     uint32_t bitMask = 1 << (t.width - 1);
101     for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) {
102       --n;
103       if (t.bits & bitMask) {
104         // delta = -n
105         MPFRNumber delta(n);
106 
107         // delta = 2^(-n)
108         mpfr_exp2(delta.value, delta.value, MPFR_RNDN);
109 
110         // delta = E * 2^(-n)
111         mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN);
112 
113         // tolerance += delta
114         mpfr_add(value, value, delta.value, MPFR_RNDN);
115       }
116     }
117   }
118 
119   template <typename XType,
120             cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
121   MPFRNumber(Operation op, XType rawValue) {
122     mpfr_init2(value, mpfrPrecision);
123     MPFRNumber mpfrInput(rawValue);
124     switch (op) {
125     case OP_Cos:
126       mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
127       break;
128     case OP_Sin:
129       mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
130       break;
131     }
132   }
133 
134   MPFRNumber(const MPFRNumber &other) {
135     mpfr_set(value, other.value, MPFR_RNDN);
136   }
137 
138   ~MPFRNumber() { mpfr_clear(value); }
139 
140   // Returns true if |other| is within the |tolerance| value of this
141   // number.
142   bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const {
143     MPFRNumber difference;
144     if (mpfr_cmp(value, other.value) >= 0)
145       mpfr_sub(difference.value, value, other.value, MPFR_RNDN);
146     else
147       mpfr_sub(difference.value, other.value, value, MPFR_RNDN);
148 
149     return mpfr_lessequal_p(difference.value, tolerance.value);
150   }
151 
152   std::string str() const {
153     // 200 bytes should be more than sufficient to hold a 100-digit number
154     // plus additional bytes for the decimal point, '-' sign etc.
155     constexpr size_t printBufSize = 200;
156     char buffer[printBufSize];
157     mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value);
158     llvm::StringRef ref(buffer);
159     ref = ref.trim();
160     return ref.str();
161   }
162 
163   // These functions are useful for debugging.
164   float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); }
165   double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); }
166   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
167 };
168 
169 namespace internal {
170 
171 template <typename T>
172 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
173   MPFRNumber mpfrResult(operation, input);
174   MPFRNumber mpfrInput(input);
175   MPFRNumber mpfrMatchValue(matchValue);
176   MPFRNumber mpfrToleranceValue(matchValue, tolerance);
177   OS << "Match value not within tolerance value of MPFR result:\n"
178      << "  Input decimal: " << mpfrInput.str() << '\n'
179      << "     Input bits: 0x" << llvm::utohexstr(getBits(input)) << '\n'
180      << "  Match decimal: " << mpfrMatchValue.str() << '\n'
181      << "     Match bits: 0x" << llvm::utohexstr(getBits(matchValue)) << '\n'
182      << "    MPFR result: " << mpfrResult.str() << '\n'
183      << "Tolerance value: " << mpfrToleranceValue.str() << '\n';
184 }
185 
186 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
187 template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
188 
189 template <typename T>
190 bool compare(Operation op, T input, T libcResult, const Tolerance &t) {
191   MPFRNumber mpfrResult(op, input);
192   MPFRNumber mpfrLibcResult(libcResult);
193   MPFRNumber mpfrToleranceValue(libcResult, t);
194 
195   return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue);
196 };
197 
198 template bool compare<float>(Operation, float, float, const Tolerance &);
199 template bool compare<double>(Operation, double, double, const Tolerance &);
200 
201 } // namespace internal
202 
203 } // namespace mpfr
204 } // namespace testing
205 } // namespace __llvm_libc
206