1 //===-- Utils which wrap MPFR ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "MPFRUtils.h" 10 11 #include <iostream> 12 #include <mpfr.h> 13 14 namespace __llvm_libc { 15 namespace testing { 16 namespace mpfr { 17 18 class MPFRNumber { 19 // A precision value which allows sufficiently large additional 20 // precision even compared to double precision floating point values. 21 static constexpr unsigned int mpfrPrecision = 96; 22 23 mpfr_t value; 24 25 public: 26 MPFRNumber() { mpfr_init2(value, mpfrPrecision); } 27 28 explicit MPFRNumber(float x) { 29 mpfr_init2(value, mpfrPrecision); 30 mpfr_set_flt(value, x, MPFR_RNDN); 31 } 32 33 MPFRNumber(const MPFRNumber &other) { 34 mpfr_set(value, other.value, MPFR_RNDN); 35 } 36 37 ~MPFRNumber() { mpfr_clear(value); } 38 39 // Returns true if |other| is within the tolerance value |t| of this 40 // number. 41 bool isEqual(const MPFRNumber &other, const Tolerance &t) { 42 MPFRNumber tolerance(0.0); 43 uint32_t bitMask = 1 << (t.width - 1); 44 for (int exponent = -t.basePrecision; bitMask > 0; bitMask >>= 1) { 45 --exponent; 46 if (t.bits & bitMask) { 47 MPFRNumber delta; 48 mpfr_set_ui_2exp(delta.value, 1, exponent, MPFR_RNDN); 49 mpfr_add(tolerance.value, tolerance.value, delta.value, MPFR_RNDN); 50 } 51 } 52 53 MPFRNumber difference; 54 if (mpfr_cmp(value, other.value) >= 0) 55 mpfr_sub(difference.value, value, other.value, MPFR_RNDN); 56 else 57 mpfr_sub(difference.value, other.value, value, MPFR_RNDN); 58 59 return mpfr_lessequal_p(difference.value, tolerance.value); 60 } 61 62 // These functions are useful for debugging. 63 float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); } 64 double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); } 65 void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } 66 67 public: 68 static MPFRNumber cos(float x) { 69 MPFRNumber result; 70 MPFRNumber mpfrX(x); 71 mpfr_cos(result.value, mpfrX.value, MPFR_RNDN); 72 return result; 73 } 74 75 static MPFRNumber sin(float x) { 76 MPFRNumber result; 77 MPFRNumber mpfrX(x); 78 mpfr_sin(result.value, mpfrX.value, MPFR_RNDN); 79 return result; 80 } 81 }; 82 83 bool equalsCos(float input, float libcOutput, const Tolerance &t) { 84 MPFRNumber mpfrResult = MPFRNumber::cos(input); 85 MPFRNumber libcResult(libcOutput); 86 return mpfrResult.isEqual(libcResult, t); 87 } 88 89 bool equalsSin(float input, float libcOutput, const Tolerance &t) { 90 MPFRNumber mpfrResult = MPFRNumber::sin(input); 91 MPFRNumber libcResult(libcOutput); 92 return mpfrResult.isEqual(libcResult, t); 93 } 94 95 } // namespace mpfr 96 } // namespace testing 97 } // namespace __llvm_libc 98