1 //===-- Utils which wrap MPFR ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MPFRUtils.h" 10 11 #include "utils/FPUtil/FloatOperations.h" 12 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringRef.h" 15 16 #include <mpfr.h> 17 #include <stdint.h> 18 #include <string> 19 20 namespace __llvm_libc { 21 namespace testing { 22 namespace mpfr { 23 24 class MPFRNumber { 25 // A precision value which allows sufficiently large additional 26 // precision even compared to double precision floating point values. 27 static constexpr unsigned int mpfrPrecision = 96; 28 29 mpfr_t value; 30 31 public: 32 MPFRNumber() { mpfr_init2(value, mpfrPrecision); } 33 34 // We use explicit EnableIf specializations to disallow implicit 35 // conversions. Implicit conversions can potentially lead to loss of 36 // precision. 37 template <typename XType, 38 cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0> 39 explicit MPFRNumber(XType x) { 40 mpfr_init2(value, mpfrPrecision); 41 mpfr_set_flt(value, x, MPFR_RNDN); 42 } 43 44 template <typename XType, 45 cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0> 46 explicit MPFRNumber(XType x) { 47 mpfr_init2(value, mpfrPrecision); 48 mpfr_set_d(value, x, MPFR_RNDN); 49 } 50 51 template <typename XType, 52 cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0> 53 explicit MPFRNumber(XType x) { 54 mpfr_init2(value, mpfrPrecision); 55 mpfr_set_sj(value, x, MPFR_RNDN); 56 } 57 58 template <typename XType> MPFRNumber(XType x, const Tolerance &t) { 59 mpfr_init2(value, mpfrPrecision); 60 mpfr_set_zero(value, 1); // Set to positive zero. 61 MPFRNumber xExponent(fputil::getExponent(x)); 62 // E = 2^E 63 mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN); 64 uint32_t bitMask = 1 << (t.width - 1); 65 for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) { 66 --n; 67 if (t.bits & bitMask) { 68 // delta = -n 69 MPFRNumber delta(n); 70 71 // delta = 2^(-n) 72 mpfr_exp2(delta.value, delta.value, MPFR_RNDN); 73 74 // delta = E * 2^(-n) 75 mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN); 76 77 // tolerance += delta 78 mpfr_add(value, value, delta.value, MPFR_RNDN); 79 } 80 } 81 } 82 83 template <typename XType, 84 cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0> 85 MPFRNumber(Operation op, XType rawValue) { 86 mpfr_init2(value, mpfrPrecision); 87 MPFRNumber mpfrInput(rawValue); 88 switch (op) { 89 case Operation::Abs: 90 mpfr_abs(value, mpfrInput.value, MPFR_RNDN); 91 break; 92 case Operation::Ceil: 93 mpfr_ceil(value, mpfrInput.value); 94 break; 95 case Operation::Cos: 96 mpfr_cos(value, mpfrInput.value, MPFR_RNDN); 97 break; 98 case Operation::Exp: 99 mpfr_exp(value, mpfrInput.value, MPFR_RNDN); 100 break; 101 case Operation::Exp2: 102 mpfr_exp2(value, mpfrInput.value, MPFR_RNDN); 103 break; 104 case Operation::Floor: 105 mpfr_floor(value, mpfrInput.value); 106 break; 107 case Operation::Round: 108 mpfr_round(value, mpfrInput.value); 109 break; 110 case Operation::Sin: 111 mpfr_sin(value, mpfrInput.value, MPFR_RNDN); 112 break; 113 case Operation::Trunc: 114 mpfr_trunc(value, mpfrInput.value); 115 break; 116 } 117 } 118 119 MPFRNumber(const MPFRNumber &other) { 120 mpfr_set(value, other.value, MPFR_RNDN); 121 } 122 123 ~MPFRNumber() { mpfr_clear(value); } 124 125 // Returns true if |other| is within the |tolerance| value of this 126 // number. 127 bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const { 128 MPFRNumber difference; 129 if (mpfr_cmp(value, other.value) >= 0) 130 mpfr_sub(difference.value, value, other.value, MPFR_RNDN); 131 else 132 mpfr_sub(difference.value, other.value, value, MPFR_RNDN); 133 134 return mpfr_lessequal_p(difference.value, tolerance.value); 135 } 136 137 std::string str() const { 138 // 200 bytes should be more than sufficient to hold a 100-digit number 139 // plus additional bytes for the decimal point, '-' sign etc. 140 constexpr size_t printBufSize = 200; 141 char buffer[printBufSize]; 142 mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value); 143 llvm::StringRef ref(buffer); 144 ref = ref.trim(); 145 return ref.str(); 146 } 147 148 // These functions are useful for debugging. 149 float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); } 150 double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); } 151 void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } 152 }; 153 154 namespace internal { 155 156 template <typename T> 157 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) { 158 using fputil::valueAsBits; 159 160 MPFRNumber mpfrResult(operation, input); 161 MPFRNumber mpfrInput(input); 162 MPFRNumber mpfrMatchValue(matchValue); 163 MPFRNumber mpfrToleranceValue(matchValue, tolerance); 164 OS << "Match value not within tolerance value of MPFR result:\n" 165 << " Input decimal: " << mpfrInput.str() << '\n' 166 << " Input bits: 0x" << llvm::utohexstr(valueAsBits(input)) << '\n' 167 << " Match decimal: " << mpfrMatchValue.str() << '\n' 168 << " Match bits: 0x" << llvm::utohexstr(valueAsBits(matchValue)) 169 << '\n' 170 << " MPFR result: " << mpfrResult.str() << '\n' 171 << "Tolerance value: " << mpfrToleranceValue.str() << '\n'; 172 } 173 174 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &); 175 template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &); 176 177 template <typename T> 178 bool compare(Operation op, T input, T libcResult, const Tolerance &t) { 179 MPFRNumber mpfrResult(op, input); 180 MPFRNumber mpfrLibcResult(libcResult); 181 MPFRNumber mpfrToleranceValue(libcResult, t); 182 183 return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue); 184 }; 185 186 template bool compare<float>(Operation, float, float, const Tolerance &); 187 template bool compare<double>(Operation, double, double, const Tolerance &); 188 189 } // namespace internal 190 191 } // namespace mpfr 192 } // namespace testing 193 } // namespace __llvm_libc 194