1 //===-- Utils which wrap MPFR ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPFRUtils.h"
10 
11 #include <iostream>
12 #include <mpfr.h>
13 
14 namespace __llvm_libc {
15 namespace testing {
16 namespace mpfr {
17 
18 class MPFRNumber {
19   // A precision value which allows sufficiently large additional
20   // precision even compared to double precision floating point values.
21   static constexpr unsigned int mpfrPrecision = 96;
22 
23   mpfr_t value;
24 
25 public:
26   MPFRNumber() { mpfr_init2(value, mpfrPrecision); }
27 
28   explicit MPFRNumber(float x) {
29     mpfr_init2(value, mpfrPrecision);
30     mpfr_set_flt(value, x, MPFR_RNDN);
31   }
32 
33   MPFRNumber(const MPFRNumber &other) {
34     mpfr_set(value, other.value, MPFR_RNDN);
35   }
36 
37   ~MPFRNumber() { mpfr_clear(value); }
38 
39   // Returns true if |other| is within the tolerance value |t| of this
40   // number.
41   bool isEqual(const MPFRNumber &other, const Tolerance &t) {
42     MPFRNumber tolerance(0.0);
43     uint32_t bitMask = 1 << (t.width - 1);
44     for (int exponent = -t.basePrecision; bitMask > 0; bitMask >>= 1) {
45       --exponent;
46       if (t.bits & bitMask) {
47         MPFRNumber delta;
48         mpfr_set_ui_2exp(delta.value, 1, exponent, MPFR_RNDN);
49         mpfr_add(tolerance.value, tolerance.value, delta.value, MPFR_RNDN);
50       }
51     }
52 
53     MPFRNumber difference;
54     if (mpfr_cmp(value, other.value) >= 0)
55       mpfr_sub(difference.value, value, other.value, MPFR_RNDN);
56     else
57       mpfr_sub(difference.value, other.value, value, MPFR_RNDN);
58 
59     return mpfr_lessequal_p(difference.value, tolerance.value);
60   }
61 
62   // These functions are useful for debugging.
63   float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); }
64   double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); }
65   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
66 
67 public:
68   static MPFRNumber cos(float x) {
69     MPFRNumber result;
70     MPFRNumber mpfrX(x);
71     mpfr_cos(result.value, mpfrX.value, MPFR_RNDN);
72     return result;
73   }
74 
75   static MPFRNumber sin(float x) {
76     MPFRNumber result;
77     MPFRNumber mpfrX(x);
78     mpfr_sin(result.value, mpfrX.value, MPFR_RNDN);
79     return result;
80   }
81 };
82 
83 bool equalsCos(float input, float libcOutput, const Tolerance &t) {
84   MPFRNumber mpfrResult = MPFRNumber::cos(input);
85   MPFRNumber libcResult(libcOutput);
86   return mpfrResult.isEqual(libcResult, t);
87 }
88 
89 bool equalsSin(float input, float libcOutput, const Tolerance &t) {
90   MPFRNumber mpfrResult = MPFRNumber::sin(input);
91   MPFRNumber libcResult(libcOutput);
92   return mpfrResult.isEqual(libcResult, t);
93 }
94 
95 } // namespace mpfr
96 } // namespace testing
97 } // namespace __llvm_libc
98