1 //===-- Utils which wrap MPFR ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MPFRUtils.h" 10 11 #include "utils/FPUtil/FPBits.h" 12 13 #include "llvm/ADT/StringExtras.h" 14 #include "llvm/ADT/StringRef.h" 15 16 #include <mpfr.h> 17 #include <stdint.h> 18 #include <string> 19 20 template <typename T> using FPBits = __llvm_libc::fputil::FPBits<T>; 21 22 namespace __llvm_libc { 23 namespace testing { 24 namespace mpfr { 25 26 class MPFRNumber { 27 // A precision value which allows sufficiently large additional 28 // precision even compared to quad-precision floating point values. 29 static constexpr unsigned int mpfrPrecision = 128; 30 31 mpfr_t value; 32 33 public: 34 MPFRNumber() { mpfr_init2(value, mpfrPrecision); } 35 36 // We use explicit EnableIf specializations to disallow implicit 37 // conversions. Implicit conversions can potentially lead to loss of 38 // precision. 39 template <typename XType, 40 cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0> 41 explicit MPFRNumber(XType x) { 42 mpfr_init2(value, mpfrPrecision); 43 mpfr_set_flt(value, x, MPFR_RNDN); 44 } 45 46 template <typename XType, 47 cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0> 48 explicit MPFRNumber(XType x) { 49 mpfr_init2(value, mpfrPrecision); 50 mpfr_set_d(value, x, MPFR_RNDN); 51 } 52 53 template <typename XType, 54 cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0> 55 explicit MPFRNumber(XType x) { 56 mpfr_init2(value, mpfrPrecision); 57 mpfr_set_ld(value, x, MPFR_RNDN); 58 } 59 60 template <typename XType, 61 cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0> 62 explicit MPFRNumber(XType x) { 63 mpfr_init2(value, mpfrPrecision); 64 mpfr_set_sj(value, x, MPFR_RNDN); 65 } 66 67 template <typename XType> MPFRNumber(XType x, const Tolerance &t) { 68 mpfr_init2(value, mpfrPrecision); 69 mpfr_set_zero(value, 1); // Set to positive zero. 70 MPFRNumber xExponent(fputil::FPBits<XType>(x).getExponent()); 71 // E = 2^E 72 mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN); 73 uint32_t bitMask = 1 << (t.width - 1); 74 for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) { 75 --n; 76 if (t.bits & bitMask) { 77 // delta = -n 78 MPFRNumber delta(n); 79 80 // delta = 2^(-n) 81 mpfr_exp2(delta.value, delta.value, MPFR_RNDN); 82 83 // delta = E * 2^(-n) 84 mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN); 85 86 // tolerance += delta 87 mpfr_add(value, value, delta.value, MPFR_RNDN); 88 } 89 } 90 } 91 92 template <typename XType, 93 cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0> 94 MPFRNumber(Operation op, XType rawValue) { 95 mpfr_init2(value, mpfrPrecision); 96 MPFRNumber mpfrInput(rawValue); 97 switch (op) { 98 case Operation::Abs: 99 mpfr_abs(value, mpfrInput.value, MPFR_RNDN); 100 break; 101 case Operation::Ceil: 102 mpfr_ceil(value, mpfrInput.value); 103 break; 104 case Operation::Cos: 105 mpfr_cos(value, mpfrInput.value, MPFR_RNDN); 106 break; 107 case Operation::Exp: 108 mpfr_exp(value, mpfrInput.value, MPFR_RNDN); 109 break; 110 case Operation::Exp2: 111 mpfr_exp2(value, mpfrInput.value, MPFR_RNDN); 112 break; 113 case Operation::Floor: 114 mpfr_floor(value, mpfrInput.value); 115 break; 116 case Operation::Round: 117 mpfr_round(value, mpfrInput.value); 118 break; 119 case Operation::Sin: 120 mpfr_sin(value, mpfrInput.value, MPFR_RNDN); 121 break; 122 case Operation::Trunc: 123 mpfr_trunc(value, mpfrInput.value); 124 break; 125 } 126 } 127 128 MPFRNumber(const MPFRNumber &other) { 129 mpfr_set(value, other.value, MPFR_RNDN); 130 } 131 132 ~MPFRNumber() { mpfr_clear(value); } 133 134 // Returns true if |other| is within the |tolerance| value of this 135 // number. 136 bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const { 137 MPFRNumber difference; 138 if (mpfr_cmp(value, other.value) >= 0) 139 mpfr_sub(difference.value, value, other.value, MPFR_RNDN); 140 else 141 mpfr_sub(difference.value, other.value, value, MPFR_RNDN); 142 143 return mpfr_lessequal_p(difference.value, tolerance.value); 144 } 145 146 std::string str() const { 147 // 200 bytes should be more than sufficient to hold a 100-digit number 148 // plus additional bytes for the decimal point, '-' sign etc. 149 constexpr size_t printBufSize = 200; 150 char buffer[printBufSize]; 151 mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value); 152 llvm::StringRef ref(buffer); 153 ref = ref.trim(); 154 return ref.str(); 155 } 156 157 // These functions are useful for debugging. 158 float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); } 159 double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); } 160 void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } 161 }; 162 163 namespace internal { 164 165 template <typename T> 166 void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) { 167 MPFRNumber mpfrResult(operation, input); 168 MPFRNumber mpfrInput(input); 169 MPFRNumber mpfrMatchValue(matchValue); 170 MPFRNumber mpfrToleranceValue(matchValue, tolerance); 171 FPBits<T> inputBits(input); 172 FPBits<T> matchBits(matchValue); 173 // TODO: Call to llvm::utohexstr implicitly converts __uint128_t values to 174 // uint64_t values. This can be fixed using a custom wrapper for 175 // llvm::utohexstr to handle __uint128_t values correctly. 176 OS << "Match value not within tolerance value of MPFR result:\n" 177 << " Input decimal: " << mpfrInput.str() << '\n' 178 << " Input bits: 0x" << llvm::utohexstr(inputBits.bitsAsUInt()) << '\n' 179 << " Match decimal: " << mpfrMatchValue.str() << '\n' 180 << " Match bits: 0x" << llvm::utohexstr(matchBits.bitsAsUInt()) << '\n' 181 << " MPFR result: " << mpfrResult.str() << '\n' 182 << "Tolerance value: " << mpfrToleranceValue.str() << '\n'; 183 } 184 185 template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &); 186 template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &); 187 template void 188 MPFRMatcher<long double>::explainError(testutils::StreamWrapper &); 189 190 template <typename T> 191 bool compare(Operation op, T input, T libcResult, const Tolerance &t) { 192 MPFRNumber mpfrResult(op, input); 193 MPFRNumber mpfrLibcResult(libcResult); 194 MPFRNumber mpfrToleranceValue(libcResult, t); 195 196 return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue); 197 }; 198 199 template bool compare<float>(Operation, float, float, const Tolerance &); 200 template bool compare<double>(Operation, double, double, const Tolerance &); 201 template bool compare<long double>(Operation, long double, long double, 202 const Tolerance &); 203 204 } // namespace internal 205 206 } // namespace mpfr 207 } // namespace testing 208 } // namespace __llvm_libc 209