1 //===-- Utils which wrap MPFR ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPFRUtils.h"
10 
11 #include "llvm/ADT/StringRef.h"
12 
13 #include <mpfr.h>
14 #include <string>
15 
16 namespace __llvm_libc {
17 namespace testing {
18 namespace mpfr {
19 
20 class MPFRNumber {
21   // A precision value which allows sufficiently large additional
22   // precision even compared to double precision floating point values.
23   static constexpr unsigned int mpfrPrecision = 96;
24 
25   mpfr_t value;
26 
27 public:
28   MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
29 
30   // We use explicit EnableIf specializations to disallow implicit
31   // conversions. Implicit conversions can potentially lead to loss of
32   // precision.
33   template <typename XType,
34             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
35   explicit MPFRNumber(XType x) {
36     mpfr_init2(value, mpfrPrecision);
37     mpfr_set_flt(value, x, MPFR_RNDN);
38   }
39 
40   template <typename XType,
41             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
42   explicit MPFRNumber(XType x) {
43     mpfr_init2(value, mpfrPrecision);
44     mpfr_set_d(value, x, MPFR_RNDN);
45   }
46 
47   template <typename XType,
48             cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
49   MPFRNumber(Operation op, XType rawValue) {
50     mpfr_init2(value, mpfrPrecision);
51     MPFRNumber mpfrInput(rawValue);
52     switch (op) {
53     case OP_Cos:
54       mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
55       break;
56     case OP_Sin:
57       mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
58       break;
59     }
60   }
61 
62   MPFRNumber(const MPFRNumber &other) {
63     mpfr_set(value, other.value, MPFR_RNDN);
64   }
65 
66   ~MPFRNumber() { mpfr_clear(value); }
67 
68   // Returns true if |other| is within the tolerance value |t| of this
69   // number.
70   bool isEqual(const MPFRNumber &other, const Tolerance &t) {
71     MPFRNumber tolerance(0.0);
72     uint32_t bitMask = 1 << (t.width - 1);
73     for (int exponent = -t.basePrecision; bitMask > 0; bitMask >>= 1) {
74       --exponent;
75       if (t.bits & bitMask) {
76         MPFRNumber delta;
77         mpfr_set_ui_2exp(delta.value, 1, exponent, MPFR_RNDN);
78         mpfr_add(tolerance.value, tolerance.value, delta.value, MPFR_RNDN);
79       }
80     }
81 
82     MPFRNumber difference;
83     if (mpfr_cmp(value, other.value) >= 0)
84       mpfr_sub(difference.value, value, other.value, MPFR_RNDN);
85     else
86       mpfr_sub(difference.value, other.value, value, MPFR_RNDN);
87 
88     return mpfr_lessequal_p(difference.value, tolerance.value);
89   }
90 
91   std::string str() const {
92     // 200 bytes should be more than sufficient to hold a 100-digit number
93     // plus additional bytes for the decimal point, '-' sign etc.
94     constexpr size_t printBufSize = 200;
95     char buffer[printBufSize];
96     mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value);
97     llvm::StringRef ref(buffer);
98     ref = ref.trim();
99     return ref.str();
100   }
101 
102   // These functions are useful for debugging.
103   float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); }
104   double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); }
105   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
106 };
107 
108 namespace internal {
109 
110 template <typename T>
111 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
112   MPFRNumber mpfrResult(operation, input);
113   MPFRNumber mpfrInput(input);
114   MPFRNumber mpfrMatchValue(matchValue);
115   OS << "Match value not within tolerance value of MPFR result:\n"
116      << "Operation input: " << mpfrInput.str() << '\n'
117      << "    Match value: " << mpfrMatchValue.str() << '\n'
118      << "    MPFR result: " << mpfrResult.str() << '\n';
119 }
120 
121 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
122 template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
123 
124 template <typename T>
125 bool compare(Operation op, T input, T libcResult, const Tolerance &t) {
126   MPFRNumber mpfrResult(op, input);
127   MPFRNumber mpfrInput(input);
128   MPFRNumber mpfrLibcResult(libcResult);
129   return mpfrResult.isEqual(mpfrLibcResult, t);
130 };
131 
132 template bool compare<float>(Operation, float, float, const Tolerance &);
133 template bool compare<double>(Operation, double, double, const Tolerance &);
134 
135 } // namespace internal
136 
137 } // namespace mpfr
138 } // namespace testing
139 } // namespace __llvm_libc
140