1 //===-- Utils which wrap MPFR ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MPFRUtils.h" 10 11 #include "utils/FPUtil/FloatOperations.h" 12 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringRef.h" 15 16 #include <mpfr.h> 17 #include <stdint.h> 18 #include <string> 19 20 namespace __llvm_libc { 21 namespace testing { 22 namespace mpfr { 23 24 class MPFRNumber { 25 // A precision value which allows sufficiently large additional 26 // precision even compared to double precision floating point values. 27 static constexpr unsigned int mpfrPrecision = 96; 28 29 mpfr_t value; 30 31 public: 32 MPFRNumber() { mpfr_init2(value, mpfrPrecision); } 33 34 // We use explicit EnableIf specializations to disallow implicit 35 // conversions. Implicit conversions can potentially lead to loss of 36 // precision. 37 template <typename XType, 38 cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0> 39 explicit MPFRNumber(XType x) { 40 mpfr_init2(value, mpfrPrecision); 41 mpfr_set_flt(value, x, MPFR_RNDN); 42 } 43 44 template <typename XType, 45 cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0> 46 explicit MPFRNumber(XType x) { 47 mpfr_init2(value, mpfrPrecision); 48 mpfr_set_d(value, x, MPFR_RNDN); 49 } 50 51 template <typename XType, 52 cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0> 53 explicit MPFRNumber(XType x) { 54 mpfr_init2(value, mpfrPrecision); 55 mpfr_set_sj(value, x, MPFR_RNDN); 56 } 57 58 template <typename XType> MPFRNumber(XType x, const Tolerance &t) { 59 mpfr_init2(value, mpfrPrecision); 60 mpfr_set_zero(value, 1); // Set to positive zero. 61 MPFRNumber xExponent(fputil::getExponent(x)); 62 // E = 2^E 63 mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN); 64 uint32_t bitMask = 1 << (t.width - 1); 65 for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) { 66 --n; 67 if (t.bits & bitMask) { 68 // delta = -n 69 MPFRNumber delta(n); 70 71 // delta = 2^(-n) 72 mpfr_exp2(delta.value, delta.value, MPFR_RNDN); 73 74 // delta = E * 2^(-n) 75 mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN); 76 77 // tolerance += delta 78 mpfr_add(value, value, delta.value, MPFR_RNDN); 79 } 80 } 81 } 82 83 template <typename XType, 84 cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0> 85 MPFRNumber(Operation op, XType rawValue) { 86 mpfr_init2(value, mpfrPrecision); 87 MPFRNumber mpfrInput(rawValue); 88 switch (op) { 89 case OP_Abs: 90 mpfr_abs(value, mpfrInput.value, MPFR_RNDN); 91 break; 92 case OP_Cos: 93 mpfr_cos(value, mpfrInput.value, MPFR_RNDN); 94 break; 95 case OP_Sin: 96 mpfr_sin(value, mpfrInput.value, MPFR_RNDN); 97 break; 98 case OP_Exp: 99 mpfr_exp(value, mpfrInput.value, MPFR_RNDN); 100 break; 101 case OP_Exp2: 102 mpfr_exp2(value, mpfrInput.value, MPFR_RNDN); 103 break; 104 } 105 } 106 107 MPFRNumber(const MPFRNumber &other) { 108 mpfr_set(value, other.value, MPFR_RNDN); 109 } 110 111 ~MPFRNumber() { mpfr_clear(value); } 112 113 // Returns true if |other| is within the |tolerance| value of this 114 // number. 115 bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const { 116 MPFRNumber difference; 117 if (mpfr_cmp(value, other.value) >= 0) 118 mpfr_sub(difference.value, value, other.value, MPFR_RNDN); 119 else 120 mpfr_sub(difference.value, other.value, value, MPFR_RNDN); 121 122 return mpfr_lessequal_p(difference.value, tolerance.value); 123 } 124 125 std::string str() const { 126 // 200 bytes should be more than sufficient to hold a 100-digit number 127 // plus additional bytes for the decimal point, '-' sign etc. 128 constexpr size_t printBufSize = 200; 129 char buffer[printBufSize]; 130 mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value); 131 llvm::StringRef ref(buffer); 132 ref = ref.trim(); 133 return ref.str(); 134 } 135 136 // These functions are useful for debugging. 137 float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); } 138 double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); } 139 void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } 140 }; 141 142 namespace internal { 143 144 template <typename T> 145 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) { 146 using fputil::valueAsBits; 147 148 MPFRNumber mpfrResult(operation, input); 149 MPFRNumber mpfrInput(input); 150 MPFRNumber mpfrMatchValue(matchValue); 151 MPFRNumber mpfrToleranceValue(matchValue, tolerance); 152 OS << "Match value not within tolerance value of MPFR result:\n" 153 << " Input decimal: " << mpfrInput.str() << '\n' 154 << " Input bits: 0x" << llvm::utohexstr(valueAsBits(input)) << '\n' 155 << " Match decimal: " << mpfrMatchValue.str() << '\n' 156 << " Match bits: 0x" << llvm::utohexstr(valueAsBits(matchValue)) 157 << '\n' 158 << " MPFR result: " << mpfrResult.str() << '\n' 159 << "Tolerance value: " << mpfrToleranceValue.str() << '\n'; 160 } 161 162 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &); 163 template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &); 164 165 template <typename T> 166 bool compare(Operation op, T input, T libcResult, const Tolerance &t) { 167 MPFRNumber mpfrResult(op, input); 168 MPFRNumber mpfrLibcResult(libcResult); 169 MPFRNumber mpfrToleranceValue(libcResult, t); 170 171 return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue); 172 }; 173 174 template bool compare<float>(Operation, float, float, const Tolerance &); 175 template bool compare<double>(Operation, double, double, const Tolerance &); 176 177 } // namespace internal 178 179 } // namespace mpfr 180 } // namespace testing 181 } // namespace __llvm_libc 182