1 //===-- Utils which wrap MPFR ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPFRUtils.h"
10 
11 #include "utils/FPUtil/FloatOperations.h"
12 
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringRef.h"
15 
16 #include <mpfr.h>
17 #include <stdint.h>
18 #include <string>
19 
20 namespace __llvm_libc {
21 namespace testing {
22 namespace mpfr {
23 
24 class MPFRNumber {
25   // A precision value which allows sufficiently large additional
26   // precision even compared to double precision floating point values.
27   static constexpr unsigned int mpfrPrecision = 96;
28 
29   mpfr_t value;
30 
31 public:
32   MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
33 
34   // We use explicit EnableIf specializations to disallow implicit
35   // conversions. Implicit conversions can potentially lead to loss of
36   // precision.
37   template <typename XType,
38             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
39   explicit MPFRNumber(XType x) {
40     mpfr_init2(value, mpfrPrecision);
41     mpfr_set_flt(value, x, MPFR_RNDN);
42   }
43 
44   template <typename XType,
45             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
46   explicit MPFRNumber(XType x) {
47     mpfr_init2(value, mpfrPrecision);
48     mpfr_set_d(value, x, MPFR_RNDN);
49   }
50 
51   template <typename XType,
52             cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
53   explicit MPFRNumber(XType x) {
54     mpfr_init2(value, mpfrPrecision);
55     mpfr_set_sj(value, x, MPFR_RNDN);
56   }
57 
58   template <typename XType> MPFRNumber(XType x, const Tolerance &t) {
59     mpfr_init2(value, mpfrPrecision);
60     mpfr_set_zero(value, 1); // Set to positive zero.
61     MPFRNumber xExponent(fputil::getExponent(x));
62     // E = 2^E
63     mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN);
64     uint32_t bitMask = 1 << (t.width - 1);
65     for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) {
66       --n;
67       if (t.bits & bitMask) {
68         // delta = -n
69         MPFRNumber delta(n);
70 
71         // delta = 2^(-n)
72         mpfr_exp2(delta.value, delta.value, MPFR_RNDN);
73 
74         // delta = E * 2^(-n)
75         mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN);
76 
77         // tolerance += delta
78         mpfr_add(value, value, delta.value, MPFR_RNDN);
79       }
80     }
81   }
82 
83   template <typename XType,
84             cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
85   MPFRNumber(Operation op, XType rawValue) {
86     mpfr_init2(value, mpfrPrecision);
87     MPFRNumber mpfrInput(rawValue);
88     switch (op) {
89     case Operation::Abs:
90       mpfr_abs(value, mpfrInput.value, MPFR_RNDN);
91       break;
92     case Operation::Ceil:
93       mpfr_ceil(value, mpfrInput.value);
94       break;
95     case Operation::Cos:
96       mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
97       break;
98     case Operation::Exp:
99       mpfr_exp(value, mpfrInput.value, MPFR_RNDN);
100       break;
101     case Operation::Exp2:
102       mpfr_exp2(value, mpfrInput.value, MPFR_RNDN);
103       break;
104     case Operation::Floor:
105       mpfr_floor(value, mpfrInput.value);
106       break;
107     case Operation::Round:
108       mpfr_round(value, mpfrInput.value);
109       break;
110     case Operation::Sin:
111       mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
112       break;
113     case Operation::Trunc:
114       mpfr_trunc(value, mpfrInput.value);
115       break;
116     }
117   }
118 
119   MPFRNumber(const MPFRNumber &other) {
120     mpfr_set(value, other.value, MPFR_RNDN);
121   }
122 
123   ~MPFRNumber() { mpfr_clear(value); }
124 
125   // Returns true if |other| is within the |tolerance| value of this
126   // number.
127   bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const {
128     MPFRNumber difference;
129     if (mpfr_cmp(value, other.value) >= 0)
130       mpfr_sub(difference.value, value, other.value, MPFR_RNDN);
131     else
132       mpfr_sub(difference.value, other.value, value, MPFR_RNDN);
133 
134     return mpfr_lessequal_p(difference.value, tolerance.value);
135   }
136 
137   std::string str() const {
138     // 200 bytes should be more than sufficient to hold a 100-digit number
139     // plus additional bytes for the decimal point, '-' sign etc.
140     constexpr size_t printBufSize = 200;
141     char buffer[printBufSize];
142     mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value);
143     llvm::StringRef ref(buffer);
144     ref = ref.trim();
145     return ref.str();
146   }
147 
148   // These functions are useful for debugging.
149   float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); }
150   double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); }
151   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
152 };
153 
154 namespace internal {
155 
156 template <typename T>
157 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
158   using fputil::valueAsBits;
159 
160   MPFRNumber mpfrResult(operation, input);
161   MPFRNumber mpfrInput(input);
162   MPFRNumber mpfrMatchValue(matchValue);
163   MPFRNumber mpfrToleranceValue(matchValue, tolerance);
164   OS << "Match value not within tolerance value of MPFR result:\n"
165      << "  Input decimal: " << mpfrInput.str() << '\n'
166      << "     Input bits: 0x" << llvm::utohexstr(valueAsBits(input)) << '\n'
167      << "  Match decimal: " << mpfrMatchValue.str() << '\n'
168      << "     Match bits: 0x" << llvm::utohexstr(valueAsBits(matchValue))
169      << '\n'
170      << "    MPFR result: " << mpfrResult.str() << '\n'
171      << "Tolerance value: " << mpfrToleranceValue.str() << '\n';
172 }
173 
174 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
175 template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
176 
177 template <typename T>
178 bool compare(Operation op, T input, T libcResult, const Tolerance &t) {
179   MPFRNumber mpfrResult(op, input);
180   MPFRNumber mpfrLibcResult(libcResult);
181   MPFRNumber mpfrToleranceValue(libcResult, t);
182 
183   return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue);
184 };
185 
186 template bool compare<float>(Operation, float, float, const Tolerance &);
187 template bool compare<double>(Operation, double, double, const Tolerance &);
188 
189 } // namespace internal
190 
191 } // namespace mpfr
192 } // namespace testing
193 } // namespace __llvm_libc
194