1 //===-- Utils which wrap MPFR ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPFRUtils.h"
10 
11 #include "utils/FPUtil/FPBits.h"
12 
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringRef.h"
15 
16 #include <mpfr.h>
17 #include <stdint.h>
18 #include <string>
19 
// File-local shorthand for fputil::FPBits<T>. It is used below to obtain the
// raw bit representation of floating point values when explaining a mismatch.
template <typename T> using FPBits = __llvm_libc::fputil::FPBits<T>;
21 
22 namespace __llvm_libc {
23 namespace testing {
24 namespace mpfr {
25 
26 class MPFRNumber {
27   // A precision value which allows sufficiently large additional
28   // precision even compared to quad-precision floating point values.
29   static constexpr unsigned int mpfrPrecision = 128;
30 
31   mpfr_t value;
32 
33 public:
34   MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
35 
36   // We use explicit EnableIf specializations to disallow implicit
37   // conversions. Implicit conversions can potentially lead to loss of
38   // precision.
39   template <typename XType,
40             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
41   explicit MPFRNumber(XType x) {
42     mpfr_init2(value, mpfrPrecision);
43     mpfr_set_flt(value, x, MPFR_RNDN);
44   }
45 
46   template <typename XType,
47             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
48   explicit MPFRNumber(XType x) {
49     mpfr_init2(value, mpfrPrecision);
50     mpfr_set_d(value, x, MPFR_RNDN);
51   }
52 
53   template <typename XType,
54             cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
55   explicit MPFRNumber(XType x) {
56     mpfr_init2(value, mpfrPrecision);
57     mpfr_set_ld(value, x, MPFR_RNDN);
58   }
59 
60   template <typename XType,
61             cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
62   explicit MPFRNumber(XType x) {
63     mpfr_init2(value, mpfrPrecision);
64     mpfr_set_sj(value, x, MPFR_RNDN);
65   }
66 
67   template <typename XType> MPFRNumber(XType x, const Tolerance &t) {
68     mpfr_init2(value, mpfrPrecision);
69     mpfr_set_zero(value, 1); // Set to positive zero.
70     MPFRNumber xExponent(fputil::FPBits<XType>(x).getExponent());
71     // E = 2^E
72     mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN);
73     uint32_t bitMask = 1 << (t.width - 1);
74     for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) {
75       --n;
76       if (t.bits & bitMask) {
77         // delta = -n
78         MPFRNumber delta(n);
79 
80         // delta = 2^(-n)
81         mpfr_exp2(delta.value, delta.value, MPFR_RNDN);
82 
83         // delta = E * 2^(-n)
84         mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN);
85 
86         // tolerance += delta
87         mpfr_add(value, value, delta.value, MPFR_RNDN);
88       }
89     }
90   }
91 
92   template <typename XType,
93             cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
94   MPFRNumber(Operation op, XType rawValue) {
95     mpfr_init2(value, mpfrPrecision);
96     MPFRNumber mpfrInput(rawValue);
97     switch (op) {
98     case Operation::Abs:
99       mpfr_abs(value, mpfrInput.value, MPFR_RNDN);
100       break;
101     case Operation::Ceil:
102       mpfr_ceil(value, mpfrInput.value);
103       break;
104     case Operation::Cos:
105       mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
106       break;
107     case Operation::Exp:
108       mpfr_exp(value, mpfrInput.value, MPFR_RNDN);
109       break;
110     case Operation::Exp2:
111       mpfr_exp2(value, mpfrInput.value, MPFR_RNDN);
112       break;
113     case Operation::Floor:
114       mpfr_floor(value, mpfrInput.value);
115       break;
116     case Operation::Round:
117       mpfr_round(value, mpfrInput.value);
118       break;
119     case Operation::Sin:
120       mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
121       break;
122     case Operation::Trunc:
123       mpfr_trunc(value, mpfrInput.value);
124       break;
125     }
126   }
127 
128   MPFRNumber(const MPFRNumber &other) {
129     mpfr_set(value, other.value, MPFR_RNDN);
130   }
131 
132   ~MPFRNumber() { mpfr_clear(value); }
133 
134   // Returns true if |other| is within the |tolerance| value of this
135   // number.
136   bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const {
137     MPFRNumber difference;
138     if (mpfr_cmp(value, other.value) >= 0)
139       mpfr_sub(difference.value, value, other.value, MPFR_RNDN);
140     else
141       mpfr_sub(difference.value, other.value, value, MPFR_RNDN);
142 
143     return mpfr_lessequal_p(difference.value, tolerance.value);
144   }
145 
146   std::string str() const {
147     // 200 bytes should be more than sufficient to hold a 100-digit number
148     // plus additional bytes for the decimal point, '-' sign etc.
149     constexpr size_t printBufSize = 200;
150     char buffer[printBufSize];
151     mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value);
152     llvm::StringRef ref(buffer);
153     ref = ref.trim();
154     return ref.str();
155   }
156 
157   // These functions are useful for debugging.
158   float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); }
159   double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); }
160   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
161 };
162 
163 namespace internal {
164 
165 template <typename T>
166 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
167   MPFRNumber mpfrResult(operation, input);
168   MPFRNumber mpfrInput(input);
169   MPFRNumber mpfrMatchValue(matchValue);
170   MPFRNumber mpfrToleranceValue(matchValue, tolerance);
171   FPBits<T> inputBits(input);
172   FPBits<T> matchBits(matchValue);
173   // TODO: Call to llvm::utohexstr implicitly converts __uint128_t values to
174   // uint64_t values. This can be fixed using a custom wrapper for
175   // llvm::utohexstr to handle __uint128_t values correctly.
176   OS << "Match value not within tolerance value of MPFR result:\n"
177      << "  Input decimal: " << mpfrInput.str() << '\n'
178      << "     Input bits: 0x" << llvm::utohexstr(inputBits.bitsAsUInt()) << '\n'
179      << "  Match decimal: " << mpfrMatchValue.str() << '\n'
180      << "     Match bits: 0x" << llvm::utohexstr(matchBits.bitsAsUInt()) << '\n'
181      << "    MPFR result: " << mpfrResult.str() << '\n'
182      << "Tolerance value: " << mpfrToleranceValue.str() << '\n';
183 }
184 
// Explicit instantiations of MPFRMatcher<T>::explainError for float, double
// and long double. NOTE(review): presumably these are needed because the
// member is defined in this source file rather than in the header - confirm.
template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
template void
MPFRMatcher<long double>::explainError(testutils::StreamWrapper &);
189 
190 template <typename T>
191 bool compare(Operation op, T input, T libcResult, const Tolerance &t) {
192   MPFRNumber mpfrResult(op, input);
193   MPFRNumber mpfrLibcResult(libcResult);
194   MPFRNumber mpfrToleranceValue(libcResult, t);
195 
196   return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue);
197 };
198 
// Explicit instantiations of compare for float, double and long double.
// NOTE(review): presumably these are needed because the function template is
// defined in this source file rather than in the header - confirm.
template bool compare<float>(Operation, float, float, const Tolerance &);
template bool compare<double>(Operation, double, double, const Tolerance &);
template bool compare<long double>(Operation, long double, long double,
                                   const Tolerance &);
203 
204 } // namespace internal
205 
206 } // namespace mpfr
207 } // namespace testing
208 } // namespace __llvm_libc
209