1 //===-- Utils which wrap MPFR ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MPFRUtils.h" 10 11 #include "llvm/ADT/StringExtras.h" 12 #include "llvm/ADT/StringRef.h" 13 14 #include <mpfr.h> 15 #include <stdint.h> 16 #include <string> 17 18 namespace __llvm_libc { 19 namespace testing { 20 namespace mpfr { 21 22 template <typename T> struct FloatProperties {}; 23 24 template <> struct FloatProperties<float> { 25 typedef uint32_t BitsType; 26 static_assert(sizeof(BitsType) == sizeof(float), 27 "Unexpected size of 'float' type."); 28 29 static constexpr uint32_t mantissaWidth = 23; 30 static constexpr BitsType signMask = 0x7FFFFFFFU; 31 static constexpr uint32_t exponentOffset = 127; 32 }; 33 34 template <> struct FloatProperties<double> { 35 typedef uint64_t BitsType; 36 static_assert(sizeof(BitsType) == sizeof(double), 37 "Unexpected size of 'double' type."); 38 39 static constexpr uint32_t mantissaWidth = 52; 40 static constexpr BitsType signMask = 0x7FFFFFFFFFFFFFFFULL; 41 static constexpr uint32_t exponentOffset = 1023; 42 }; 43 44 template <typename T> typename FloatProperties<T>::BitsType getBits(T x) { 45 using BitsType = typename FloatProperties<T>::BitsType; 46 return *reinterpret_cast<BitsType *>(&x); 47 } 48 49 // Returns the zero adjusted exponent value of abs(x). 50 template <typename T> int getExponent(T x) { 51 using Properties = FloatProperties<T>; 52 using BitsType = typename Properties::BitsType; 53 BitsType bits = *reinterpret_cast<BitsType *>(&x); 54 bits &= Properties::signMask; // Zero the sign bit. 55 int e = (bits >> Properties::mantissaWidth); // Shift out the mantissa. 56 e -= Properties::exponentOffset; // Zero adjust. 57 return e; 58 } 59 60 class MPFRNumber { 61 // A precision value which allows sufficiently large additional 62 // precision even compared to double precision floating point values. 63 static constexpr unsigned int mpfrPrecision = 96; 64 65 mpfr_t value; 66 67 public: 68 MPFRNumber() { mpfr_init2(value, mpfrPrecision); } 69 70 // We use explicit EnableIf specializations to disallow implicit 71 // conversions. Implicit conversions can potentially lead to loss of 72 // precision. 73 template <typename XType, 74 cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0> 75 explicit MPFRNumber(XType x) { 76 mpfr_init2(value, mpfrPrecision); 77 mpfr_set_flt(value, x, MPFR_RNDN); 78 } 79 80 template <typename XType, 81 cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0> 82 explicit MPFRNumber(XType x) { 83 mpfr_init2(value, mpfrPrecision); 84 mpfr_set_d(value, x, MPFR_RNDN); 85 } 86 87 template <typename XType, 88 cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0> 89 explicit MPFRNumber(XType x) { 90 mpfr_init2(value, mpfrPrecision); 91 mpfr_set_sj(value, x, MPFR_RNDN); 92 } 93 94 template <typename XType> MPFRNumber(XType x, const Tolerance &t) { 95 mpfr_init2(value, mpfrPrecision); 96 mpfr_set_zero(value, 1); // Set to positive zero. 97 MPFRNumber xExponent(getExponent(x)); 98 // E = 2^E 99 mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN); 100 uint32_t bitMask = 1 << (t.width - 1); 101 for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) { 102 --n; 103 if (t.bits & bitMask) { 104 // delta = -n 105 MPFRNumber delta(n); 106 107 // delta = 2^(-n) 108 mpfr_exp2(delta.value, delta.value, MPFR_RNDN); 109 110 // delta = E * 2^(-n) 111 mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN); 112 113 // tolerance += delta 114 mpfr_add(value, value, delta.value, MPFR_RNDN); 115 } 116 } 117 } 118 119 template <typename XType, 120 cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0> 121 MPFRNumber(Operation op, XType rawValue) { 122 mpfr_init2(value, mpfrPrecision); 123 MPFRNumber mpfrInput(rawValue); 124 switch (op) { 125 case OP_Cos: 126 mpfr_cos(value, mpfrInput.value, MPFR_RNDN); 127 break; 128 case OP_Sin: 129 mpfr_sin(value, mpfrInput.value, MPFR_RNDN); 130 break; 131 } 132 } 133 134 MPFRNumber(const MPFRNumber &other) { 135 mpfr_set(value, other.value, MPFR_RNDN); 136 } 137 138 ~MPFRNumber() { mpfr_clear(value); } 139 140 // Returns true if |other| is within the |tolerance| value of this 141 // number. 142 bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const { 143 MPFRNumber difference; 144 if (mpfr_cmp(value, other.value) >= 0) 145 mpfr_sub(difference.value, value, other.value, MPFR_RNDN); 146 else 147 mpfr_sub(difference.value, other.value, value, MPFR_RNDN); 148 149 return mpfr_lessequal_p(difference.value, tolerance.value); 150 } 151 152 std::string str() const { 153 // 200 bytes should be more than sufficient to hold a 100-digit number 154 // plus additional bytes for the decimal point, '-' sign etc. 155 constexpr size_t printBufSize = 200; 156 char buffer[printBufSize]; 157 mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value); 158 llvm::StringRef ref(buffer); 159 ref = ref.trim(); 160 return ref.str(); 161 } 162 163 // These functions are useful for debugging. 164 float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); } 165 double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); } 166 void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } 167 }; 168 169 namespace internal { 170 171 template <typename T> 172 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) { 173 MPFRNumber mpfrResult(operation, input); 174 MPFRNumber mpfrInput(input); 175 MPFRNumber mpfrMatchValue(matchValue); 176 MPFRNumber mpfrToleranceValue(matchValue, tolerance); 177 OS << "Match value not within tolerance value of MPFR result:\n" 178 << " Input decimal: " << mpfrInput.str() << '\n' 179 << " Input bits: 0x" << llvm::utohexstr(getBits(input)) << '\n' 180 << " Match decimal: " << mpfrMatchValue.str() << '\n' 181 << " Match bits: 0x" << llvm::utohexstr(getBits(matchValue)) << '\n' 182 << " MPFR result: " << mpfrResult.str() << '\n' 183 << "Tolerance value: " << mpfrToleranceValue.str() << '\n'; 184 } 185 186 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &); 187 template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &); 188 189 template <typename T> 190 bool compare(Operation op, T input, T libcResult, const Tolerance &t) { 191 MPFRNumber mpfrResult(op, input); 192 MPFRNumber mpfrLibcResult(libcResult); 193 MPFRNumber mpfrToleranceValue(libcResult, t); 194 195 return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue); 196 }; 197 198 template bool compare<float>(Operation, float, float, const Tolerance &); 199 template bool compare<double>(Operation, double, double, const Tolerance &); 200 201 } // namespace internal 202 203 } // namespace mpfr 204 } // namespace testing 205 } // namespace __llvm_libc 206