1 //===-- Utils which wrap MPFR ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPFRUtils.h"
10 
11 #include "utils/FPUtil/FloatOperations.h"
12 
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringRef.h"
15 
16 #include <mpfr.h>
17 #include <stdint.h>
18 #include <string>
19 
20 namespace __llvm_libc {
21 namespace testing {
22 namespace mpfr {
23 
24 class MPFRNumber {
25   // A precision value which allows sufficiently large additional
26   // precision even compared to double precision floating point values.
27   static constexpr unsigned int mpfrPrecision = 96;
28 
29   mpfr_t value;
30 
31 public:
32   MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
33 
34   // We use explicit EnableIf specializations to disallow implicit
35   // conversions. Implicit conversions can potentially lead to loss of
36   // precision.
37   template <typename XType,
38             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
39   explicit MPFRNumber(XType x) {
40     mpfr_init2(value, mpfrPrecision);
41     mpfr_set_flt(value, x, MPFR_RNDN);
42   }
43 
44   template <typename XType,
45             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
46   explicit MPFRNumber(XType x) {
47     mpfr_init2(value, mpfrPrecision);
48     mpfr_set_d(value, x, MPFR_RNDN);
49   }
50 
51   template <typename XType,
52             cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
53   explicit MPFRNumber(XType x) {
54     mpfr_init2(value, mpfrPrecision);
55     mpfr_set_sj(value, x, MPFR_RNDN);
56   }
57 
58   template <typename XType> MPFRNumber(XType x, const Tolerance &t) {
59     mpfr_init2(value, mpfrPrecision);
60     mpfr_set_zero(value, 1); // Set to positive zero.
61     MPFRNumber xExponent(fputil::getExponent(x));
62     // E = 2^E
63     mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN);
64     uint32_t bitMask = 1 << (t.width - 1);
65     for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) {
66       --n;
67       if (t.bits & bitMask) {
68         // delta = -n
69         MPFRNumber delta(n);
70 
71         // delta = 2^(-n)
72         mpfr_exp2(delta.value, delta.value, MPFR_RNDN);
73 
74         // delta = E * 2^(-n)
75         mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN);
76 
77         // tolerance += delta
78         mpfr_add(value, value, delta.value, MPFR_RNDN);
79       }
80     }
81   }
82 
83   template <typename XType,
84             cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
85   MPFRNumber(Operation op, XType rawValue) {
86     mpfr_init2(value, mpfrPrecision);
87     MPFRNumber mpfrInput(rawValue);
88     switch (op) {
89     case OP_Abs:
90       mpfr_abs(value, mpfrInput.value, MPFR_RNDN);
91       break;
92     case OP_Cos:
93       mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
94       break;
95     case OP_Sin:
96       mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
97       break;
98     case OP_Exp:
99       mpfr_exp(value, mpfrInput.value, MPFR_RNDN);
100       break;
101     case OP_Exp2:
102       mpfr_exp2(value, mpfrInput.value, MPFR_RNDN);
103       break;
104     }
105   }
106 
107   MPFRNumber(const MPFRNumber &other) {
108     mpfr_set(value, other.value, MPFR_RNDN);
109   }
110 
111   ~MPFRNumber() { mpfr_clear(value); }
112 
113   // Returns true if |other| is within the |tolerance| value of this
114   // number.
115   bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const {
116     MPFRNumber difference;
117     if (mpfr_cmp(value, other.value) >= 0)
118       mpfr_sub(difference.value, value, other.value, MPFR_RNDN);
119     else
120       mpfr_sub(difference.value, other.value, value, MPFR_RNDN);
121 
122     return mpfr_lessequal_p(difference.value, tolerance.value);
123   }
124 
125   std::string str() const {
126     // 200 bytes should be more than sufficient to hold a 100-digit number
127     // plus additional bytes for the decimal point, '-' sign etc.
128     constexpr size_t printBufSize = 200;
129     char buffer[printBufSize];
130     mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value);
131     llvm::StringRef ref(buffer);
132     ref = ref.trim();
133     return ref.str();
134   }
135 
136   // These functions are useful for debugging.
137   float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); }
138   double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); }
139   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
140 };
141 
142 namespace internal {
143 
144 template <typename T>
145 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
146   using fputil::valueAsBits;
147 
148   MPFRNumber mpfrResult(operation, input);
149   MPFRNumber mpfrInput(input);
150   MPFRNumber mpfrMatchValue(matchValue);
151   MPFRNumber mpfrToleranceValue(matchValue, tolerance);
152   OS << "Match value not within tolerance value of MPFR result:\n"
153      << "  Input decimal: " << mpfrInput.str() << '\n'
154      << "     Input bits: 0x" << llvm::utohexstr(valueAsBits(input)) << '\n'
155      << "  Match decimal: " << mpfrMatchValue.str() << '\n'
156      << "     Match bits: 0x" << llvm::utohexstr(valueAsBits(matchValue))
157      << '\n'
158      << "    MPFR result: " << mpfrResult.str() << '\n'
159      << "Tolerance value: " << mpfrToleranceValue.str() << '\n';
160 }
161 
162 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
163 template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
164 
165 template <typename T>
166 bool compare(Operation op, T input, T libcResult, const Tolerance &t) {
167   MPFRNumber mpfrResult(op, input);
168   MPFRNumber mpfrLibcResult(libcResult);
169   MPFRNumber mpfrToleranceValue(libcResult, t);
170 
171   return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue);
172 };
173 
174 template bool compare<float>(Operation, float, float, const Tolerance &);
175 template bool compare<double>(Operation, double, double, const Tolerance &);
176 
177 } // namespace internal
178 
179 } // namespace mpfr
180 } // namespace testing
181 } // namespace __llvm_libc
182