1 //===-- Utils which wrap MPFR ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPFRUtils.h"
10 
11 #include "src/__support/CPP/StringView.h"
12 #include "src/__support/FPUtil/FPBits.h"
13 #include "src/__support/architectures.h"
14 #include "utils/UnitTest/FPMatcher.h"
15 
#include <cmath>
#include <limits>
#include <memory>
#include <stdint.h>
#include <string>
20 
21 #ifdef CUSTOM_MPFR_INCLUDER
22 // Some downstream repos are monoliths carrying MPFR sources in their third
23 // party directory. In such repos, including the MPFR header as
24 // `#include <mpfr.h>` is either disallowed or not possible. If that is the
25 // case, a file named `CustomMPFRIncluder.h` should be added through which the
26 // MPFR header can be included in manner allowed in that repo.
27 #include "CustomMPFRIncluder.h"
28 #else
29 #include <mpfr.h>
30 #endif
31 
// File-wide shorthand so the helpers below can write `FPBits<T>` instead of
// the fully qualified `__llvm_libc::fputil::FPBits<T>`.
template <typename T> using FPBits = __llvm_libc::fputil::FPBits<T>;
33 
34 namespace __llvm_libc {
35 namespace testing {
36 namespace mpfr {
37 
// Maps a floating point type to the number of significand bits (including the
// leading bit) of its format. The value is used as the MPFR working precision
// when an operation has to be evaluated in exactly the precision of the
// floating point type being tested.
template <typename T> struct Precision;

// IEEE-754 single precision.
template <> struct Precision<float> {
  static constexpr unsigned int value = 24;
};

// IEEE-754 double precision.
template <> struct Precision<double> {
  static constexpr unsigned int value = 53;
};

// The format of long double differs between targets (for example, the x87
// 80-bit extended format has a 64-bit significand while quad precision has a
// 113-bit one). Ask the compiler for the significand width of the type instead
// of keying the value on an architecture macro, so the precision always
// matches the long double actually in use.
template <> struct Precision<long double> {
  static constexpr unsigned int value =
      std::numeric_limits<long double>::digits;
};
57 
58 class MPFRNumber {
59   // A precision value which allows sufficiently large additional
60   // precision even compared to quad-precision floating point values.
61   unsigned int mpfrPrecision;
62 
63   mpfr_t value;
64 
65 public:
66   MPFRNumber() : mpfrPrecision(256) { mpfr_init2(value, mpfrPrecision); }
67 
68   // We use explicit EnableIf specializations to disallow implicit
69   // conversions. Implicit conversions can potentially lead to loss of
70   // precision.
71   template <typename XType,
72             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
73   explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
74     mpfr_init2(value, mpfrPrecision);
75     mpfr_set_flt(value, x, MPFR_RNDN);
76   }
77 
78   template <typename XType,
79             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
80   explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
81     mpfr_init2(value, mpfrPrecision);
82     mpfr_set_d(value, x, MPFR_RNDN);
83   }
84 
85   template <typename XType,
86             cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
87   explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
88     mpfr_init2(value, mpfrPrecision);
89     mpfr_set_ld(value, x, MPFR_RNDN);
90   }
91 
92   template <typename XType,
93             cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
94   explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
95     mpfr_init2(value, mpfrPrecision);
96     mpfr_set_sj(value, x, MPFR_RNDN);
97   }
98 
99   MPFRNumber(const MPFRNumber &other) : mpfrPrecision(other.mpfrPrecision) {
100     mpfr_init2(value, mpfrPrecision);
101     mpfr_set(value, other.value, MPFR_RNDN);
102   }
103 
104   ~MPFRNumber() { mpfr_clear(value); }
105 
106   MPFRNumber &operator=(const MPFRNumber &rhs) {
107     mpfrPrecision = rhs.mpfrPrecision;
108     mpfr_set(value, rhs.value, MPFR_RNDN);
109     return *this;
110   }
111 
112   MPFRNumber abs() const {
113     MPFRNumber result;
114     mpfr_abs(result.value, value, MPFR_RNDN);
115     return result;
116   }
117 
118   MPFRNumber ceil() const {
119     MPFRNumber result;
120     mpfr_ceil(result.value, value);
121     return result;
122   }
123 
124   MPFRNumber cos() const {
125     MPFRNumber result;
126     mpfr_cos(result.value, value, MPFR_RNDN);
127     return result;
128   }
129 
130   MPFRNumber exp() const {
131     MPFRNumber result;
132     mpfr_exp(result.value, value, MPFR_RNDN);
133     return result;
134   }
135 
136   MPFRNumber exp2() const {
137     MPFRNumber result;
138     mpfr_exp2(result.value, value, MPFR_RNDN);
139     return result;
140   }
141 
142   MPFRNumber expm1() const {
143     MPFRNumber result;
144     mpfr_expm1(result.value, value, MPFR_RNDN);
145     return result;
146   }
147 
148   MPFRNumber floor() const {
149     MPFRNumber result;
150     mpfr_floor(result.value, value);
151     return result;
152   }
153 
154   MPFRNumber frexp(int &exp) {
155     MPFRNumber result;
156     mpfr_exp_t resultExp;
157     mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
158     exp = resultExp;
159     return result;
160   }
161 
162   MPFRNumber hypot(const MPFRNumber &b) {
163     MPFRNumber result;
164     mpfr_hypot(result.value, value, b.value, MPFR_RNDN);
165     return result;
166   }
167 
168   MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
169     MPFRNumber remainder;
170     long q;
171     mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
172     quotient = q;
173     return remainder;
174   }
175 
176   MPFRNumber round() const {
177     MPFRNumber result;
178     mpfr_round(result.value, value);
179     return result;
180   }
181 
182   bool roundToLong(long &result) const {
183     // We first calculate the rounded value. This way, when converting
184     // to long using mpfr_get_si, the rounding direction of MPFR_RNDN
185     // (or any other rounding mode), does not have an influence.
186     MPFRNumber roundedValue = round();
187     mpfr_clear_erangeflag();
188     result = mpfr_get_si(roundedValue.value, MPFR_RNDN);
189     return mpfr_erangeflag_p();
190   }
191 
192   bool roundToLong(mpfr_rnd_t rnd, long &result) const {
193     MPFRNumber rint_result;
194     mpfr_rint(rint_result.value, value, rnd);
195     return rint_result.roundToLong(result);
196   }
197 
198   MPFRNumber rint(mpfr_rnd_t rnd) const {
199     MPFRNumber result;
200     mpfr_rint(result.value, value, rnd);
201     return result;
202   }
203 
204   MPFRNumber mod_2pi() const {
205     MPFRNumber result(0.0, 1280);
206     MPFRNumber _2pi(0.0, 1280);
207     mpfr_const_pi(_2pi.value, MPFR_RNDN);
208     mpfr_mul_si(_2pi.value, _2pi.value, 2, MPFR_RNDN);
209     mpfr_fmod(result.value, value, _2pi.value, MPFR_RNDN);
210     return result;
211   }
212 
213   MPFRNumber mod_pi_over_2() const {
214     MPFRNumber result(0.0, 1280);
215     MPFRNumber pi_over_2(0.0, 1280);
216     mpfr_const_pi(pi_over_2.value, MPFR_RNDN);
217     mpfr_mul_d(pi_over_2.value, pi_over_2.value, 0.5, MPFR_RNDN);
218     mpfr_fmod(result.value, value, pi_over_2.value, MPFR_RNDN);
219     return result;
220   }
221 
222   MPFRNumber mod_pi_over_4() const {
223     MPFRNumber result(0.0, 1280);
224     MPFRNumber pi_over_4(0.0, 1280);
225     mpfr_const_pi(pi_over_4.value, MPFR_RNDN);
226     mpfr_mul_d(pi_over_4.value, pi_over_4.value, 0.25, MPFR_RNDN);
227     mpfr_fmod(result.value, value, pi_over_4.value, MPFR_RNDN);
228     return result;
229   }
230 
231   MPFRNumber sin() const {
232     MPFRNumber result;
233     mpfr_sin(result.value, value, MPFR_RNDN);
234     return result;
235   }
236 
237   MPFRNumber sqrt() const {
238     MPFRNumber result;
239     mpfr_sqrt(result.value, value, MPFR_RNDN);
240     return result;
241   }
242 
243   MPFRNumber tan() const {
244     MPFRNumber result;
245     mpfr_tan(result.value, value, MPFR_RNDN);
246     return result;
247   }
248 
249   MPFRNumber trunc() const {
250     MPFRNumber result;
251     mpfr_trunc(result.value, value);
252     return result;
253   }
254 
255   MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
256     MPFRNumber result(*this);
257     mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN);
258     return result;
259   }
260 
261   std::string str() const {
262     // 200 bytes should be more than sufficient to hold a 100-digit number
263     // plus additional bytes for the decimal point, '-' sign etc.
264     constexpr size_t printBufSize = 200;
265     char buffer[printBufSize];
266     mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value);
267     cpp::StringView view(buffer);
268     view = view.trim(' ');
269     return std::string(view.data());
270   }
271 
272   // These functions are useful for debugging.
273   template <typename T> T as() const;
274 
275   template <> float as<float>() const { return mpfr_get_flt(value, MPFR_RNDN); }
276   template <> double as<double>() const { return mpfr_get_d(value, MPFR_RNDN); }
277   template <> long double as<long double>() const {
278     return mpfr_get_ld(value, MPFR_RNDN);
279   }
280 
281   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
282 
283   // Return the ULP (units-in-the-last-place) difference between the
284   // stored MPFR and a floating point number.
285   //
286   // We define ULP difference as follows:
287   //   If exponents of this value and the |input| are same, then:
288   //     ULP(this_value, input) = abs(this_value - input) / eps(input)
289   //   else:
290   //     max = max(abs(this_value), abs(input))
291   //     min = min(abs(this_value), abs(input))
292   //     maxExponent = exponent(max)
293   //     ULP(this_value, input) = (max - 2^maxExponent) / eps(max) +
294   //                              (2^maxExponent - min) / eps(min)
295   //
296   // Remarks:
297   // 1. A ULP of 0.0 will imply that the value is correctly rounded.
298   // 2. We expect that this value and the value to be compared (the [input]
299   //    argument) are reasonable close, and we will provide an upper bound
300   //    of ULP value for testing.  Morever, most of the fractional parts of
301   //    ULP value do not matter much, so using double as the return type
302   //    should be good enough.
303   // 3. For close enough values (values which don't diff in their exponent by
304   //    not more than 1), a ULP difference of N indicates a bit distance
305   //    of N between this number and [input].
306   // 4. A values of +0.0 and -0.0 are treated as equal.
307   template <typename T>
308   cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, double> ulp(T input) {
309     T thisAsT = as<T>();
310     if (thisAsT == input)
311       return T(0.0);
312 
313     int thisExponent = fputil::FPBits<T>(thisAsT).get_exponent();
314     int inputExponent = fputil::FPBits<T>(input).get_exponent();
315     // Adjust the exponents for denormal numbers.
316     if (fputil::FPBits<T>(thisAsT).get_unbiased_exponent() == 0)
317       ++thisExponent;
318     if (fputil::FPBits<T>(input).get_unbiased_exponent() == 0)
319       ++inputExponent;
320 
321     if (thisAsT * input < 0 || thisExponent == inputExponent) {
322       MPFRNumber inputMPFR(input);
323       mpfr_sub(inputMPFR.value, value, inputMPFR.value, MPFR_RNDN);
324       mpfr_abs(inputMPFR.value, inputMPFR.value, MPFR_RNDN);
325       mpfr_mul_2si(inputMPFR.value, inputMPFR.value,
326                    -thisExponent + int(fputil::MantissaWidth<T>::VALUE),
327                    MPFR_RNDN);
328       return inputMPFR.as<double>();
329     }
330 
331     // If the control reaches here, it means that this number and input are
332     // of the same sign but different exponent. In such a case, ULP error is
333     // calculated as sum of two parts.
334     thisAsT = std::abs(thisAsT);
335     input = std::abs(input);
336     T min = thisAsT > input ? input : thisAsT;
337     T max = thisAsT > input ? thisAsT : input;
338     int minExponent = fputil::FPBits<T>(min).get_exponent();
339     int maxExponent = fputil::FPBits<T>(max).get_exponent();
340     // Adjust the exponents for denormal numbers.
341     if (fputil::FPBits<T>(min).get_unbiased_exponent() == 0)
342       ++minExponent;
343     if (fputil::FPBits<T>(max).get_unbiased_exponent() == 0)
344       ++maxExponent;
345 
346     MPFRNumber minMPFR(min);
347     MPFRNumber maxMPFR(max);
348 
349     MPFRNumber pivot(uint32_t(1));
350     mpfr_mul_2si(pivot.value, pivot.value, maxExponent, MPFR_RNDN);
351 
352     mpfr_sub(minMPFR.value, pivot.value, minMPFR.value, MPFR_RNDN);
353     mpfr_mul_2si(minMPFR.value, minMPFR.value,
354                  -minExponent + int(fputil::MantissaWidth<T>::VALUE),
355                  MPFR_RNDN);
356 
357     mpfr_sub(maxMPFR.value, maxMPFR.value, pivot.value, MPFR_RNDN);
358     mpfr_mul_2si(maxMPFR.value, maxMPFR.value,
359                  -maxExponent + int(fputil::MantissaWidth<T>::VALUE),
360                  MPFR_RNDN);
361 
362     mpfr_add(minMPFR.value, minMPFR.value, maxMPFR.value, MPFR_RNDN);
363     return minMPFR.as<double>();
364   }
365 };
366 
367 namespace internal {
368 
369 template <typename InputType>
370 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
371 unaryOperation(Operation op, InputType input) {
372   MPFRNumber mpfrInput(input);
373   switch (op) {
374   case Operation::Abs:
375     return mpfrInput.abs();
376   case Operation::Ceil:
377     return mpfrInput.ceil();
378   case Operation::Cos:
379     return mpfrInput.cos();
380   case Operation::Exp:
381     return mpfrInput.exp();
382   case Operation::Exp2:
383     return mpfrInput.exp2();
384   case Operation::Expm1:
385     return mpfrInput.expm1();
386   case Operation::Floor:
387     return mpfrInput.floor();
388   case Operation::Mod2PI:
389     return mpfrInput.mod_2pi();
390   case Operation::ModPIOver2:
391     return mpfrInput.mod_pi_over_2();
392   case Operation::ModPIOver4:
393     return mpfrInput.mod_pi_over_4();
394   case Operation::Round:
395     return mpfrInput.round();
396   case Operation::Sin:
397     return mpfrInput.sin();
398   case Operation::Sqrt:
399     return mpfrInput.sqrt();
400   case Operation::Tan:
401     return mpfrInput.tan();
402   case Operation::Trunc:
403     return mpfrInput.trunc();
404   default:
405     __builtin_unreachable();
406   }
407 }
408 
409 template <typename InputType>
410 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
411 unaryOperationTwoOutputs(Operation op, InputType input, int &output) {
412   MPFRNumber mpfrInput(input);
413   switch (op) {
414   case Operation::Frexp:
415     return mpfrInput.frexp(output);
416   default:
417     __builtin_unreachable();
418   }
419 }
420 
421 template <typename InputType>
422 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
423 binaryOperationOneOutput(Operation op, InputType x, InputType y) {
424   MPFRNumber inputX(x), inputY(y);
425   switch (op) {
426   case Operation::Hypot:
427     return inputX.hypot(inputY);
428   default:
429     __builtin_unreachable();
430   }
431 }
432 
433 template <typename InputType>
434 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
435 binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
436   MPFRNumber inputX(x), inputY(y);
437   switch (op) {
438   case Operation::RemQuo:
439     return inputX.remquo(inputY, output);
440   default:
441     __builtin_unreachable();
442   }
443 }
444 
445 template <typename InputType>
446 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
447 ternaryOperationOneOutput(Operation op, InputType x, InputType y, InputType z) {
448   // For FMA function, we just need to compare with the mpfr_fma with the same
449   // precision as InputType.  Using higher precision as the intermediate results
450   // to compare might incorrectly fail due to double-rounding errors.
451   constexpr unsigned int prec = Precision<InputType>::value;
452   MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec);
453   switch (op) {
454   case Operation::Fma:
455     return inputX.fma(inputY, inputZ);
456   default:
457     __builtin_unreachable();
458   }
459 }
460 
461 template <typename T>
462 void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
463                                             testutils::StreamWrapper &OS) {
464   MPFRNumber mpfrInput(input);
465   MPFRNumber mpfrResult = unaryOperation(op, input);
466   MPFRNumber mpfrMatchValue(matchValue);
467   FPBits<T> inputBits(input);
468   FPBits<T> matchBits(matchValue);
469   FPBits<T> mpfrResultBits(mpfrResult.as<T>());
470   OS << "Match value not within tolerance value of MPFR result:\n"
471      << "  Input decimal: " << mpfrInput.str() << '\n';
472   __llvm_libc::fputil::testing::describeValue("     Input bits: ", input, OS);
473   OS << '\n' << "  Match decimal: " << mpfrMatchValue.str() << '\n';
474   __llvm_libc::fputil::testing::describeValue("     Match bits: ", matchValue,
475                                               OS);
476   OS << '\n' << "    MPFR result: " << mpfrResult.str() << '\n';
477   __llvm_libc::fputil::testing::describeValue(
478       "   MPFR rounded: ", mpfrResult.as<T>(), OS);
479   OS << '\n';
480   OS << "      ULP error: " << std::to_string(mpfrResult.ulp(matchValue))
481      << '\n';
482 }
483 
484 template void
485 explainUnaryOperationSingleOutputError<float>(Operation op, float, float,
486                                               testutils::StreamWrapper &);
487 template void
488 explainUnaryOperationSingleOutputError<double>(Operation op, double, double,
489                                                testutils::StreamWrapper &);
490 template void explainUnaryOperationSingleOutputError<long double>(
491     Operation op, long double, long double, testutils::StreamWrapper &);
492 
493 template <typename T>
494 void explainUnaryOperationTwoOutputsError(Operation op, T input,
495                                           const BinaryOutput<T> &libcResult,
496                                           testutils::StreamWrapper &OS) {
497   MPFRNumber mpfrInput(input);
498   FPBits<T> inputBits(input);
499   int mpfrIntResult;
500   MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
501 
502   if (mpfrIntResult != libcResult.i) {
503     OS << "MPFR integral result: " << mpfrIntResult << '\n'
504        << "Libc integral result: " << libcResult.i << '\n';
505   } else {
506     OS << "Integral result from libc matches integral result from MPFR.\n";
507   }
508 
509   MPFRNumber mpfrMatchValue(libcResult.f);
510   OS << "Libc floating point result is not within tolerance value of the MPFR "
511      << "result.\n\n";
512 
513   OS << "            Input decimal: " << mpfrInput.str() << "\n\n";
514 
515   OS << "Libc floating point value: " << mpfrMatchValue.str() << '\n';
516   __llvm_libc::fputil::testing::describeValue(
517       " Libc floating point bits: ", libcResult.f, OS);
518   OS << "\n\n";
519 
520   OS << "              MPFR result: " << mpfrResult.str() << '\n';
521   __llvm_libc::fputil::testing::describeValue(
522       "             MPFR rounded: ", mpfrResult.as<T>(), OS);
523   OS << '\n'
524      << "                ULP error: "
525      << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
526 }
527 
528 template void explainUnaryOperationTwoOutputsError<float>(
529     Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
530 template void
531 explainUnaryOperationTwoOutputsError<double>(Operation, double,
532                                              const BinaryOutput<double> &,
533                                              testutils::StreamWrapper &);
534 template void explainUnaryOperationTwoOutputsError<long double>(
535     Operation, long double, const BinaryOutput<long double> &,
536     testutils::StreamWrapper &);
537 
538 template <typename T>
539 void explainBinaryOperationTwoOutputsError(Operation op,
540                                            const BinaryInput<T> &input,
541                                            const BinaryOutput<T> &libcResult,
542                                            testutils::StreamWrapper &OS) {
543   MPFRNumber mpfrX(input.x);
544   MPFRNumber mpfrY(input.y);
545   FPBits<T> xbits(input.x);
546   FPBits<T> ybits(input.y);
547   int mpfrIntResult;
548   MPFRNumber mpfrResult =
549       binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
550   MPFRNumber mpfrMatchValue(libcResult.f);
551 
552   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
553      << "MPFR integral result: " << mpfrIntResult << '\n'
554      << "Libc integral result: " << libcResult.i << '\n'
555      << "Libc floating point result: " << mpfrMatchValue.str() << '\n'
556      << "               MPFR result: " << mpfrResult.str() << '\n';
557   __llvm_libc::fputil::testing::describeValue(
558       "Libc floating point result bits: ", libcResult.f, OS);
559   __llvm_libc::fputil::testing::describeValue(
560       "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
561   OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
562 }
563 
564 template void explainBinaryOperationTwoOutputsError<float>(
565     Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
566     testutils::StreamWrapper &);
567 template void explainBinaryOperationTwoOutputsError<double>(
568     Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
569     testutils::StreamWrapper &);
570 template void explainBinaryOperationTwoOutputsError<long double>(
571     Operation, const BinaryInput<long double> &,
572     const BinaryOutput<long double> &, testutils::StreamWrapper &);
573 
574 template <typename T>
575 void explainBinaryOperationOneOutputError(Operation op,
576                                           const BinaryInput<T> &input,
577                                           T libcResult,
578                                           testutils::StreamWrapper &OS) {
579   MPFRNumber mpfrX(input.x);
580   MPFRNumber mpfrY(input.y);
581   FPBits<T> xbits(input.x);
582   FPBits<T> ybits(input.y);
583   MPFRNumber mpfrResult = binaryOperationOneOutput(op, input.x, input.y);
584   MPFRNumber mpfrMatchValue(libcResult);
585 
586   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
587   __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
588                                               OS);
589   __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
590                                               OS);
591 
592   OS << "Libc result: " << mpfrMatchValue.str() << '\n'
593      << "MPFR result: " << mpfrResult.str() << '\n';
594   __llvm_libc::fputil::testing::describeValue(
595       "Libc floating point result bits: ", libcResult, OS);
596   __llvm_libc::fputil::testing::describeValue(
597       "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
598   OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
599 }
600 
601 template void explainBinaryOperationOneOutputError<float>(
602     Operation, const BinaryInput<float> &, float, testutils::StreamWrapper &);
603 template void explainBinaryOperationOneOutputError<double>(
604     Operation, const BinaryInput<double> &, double, testutils::StreamWrapper &);
605 template void explainBinaryOperationOneOutputError<long double>(
606     Operation, const BinaryInput<long double> &, long double,
607     testutils::StreamWrapper &);
608 
609 template <typename T>
610 void explainTernaryOperationOneOutputError(Operation op,
611                                            const TernaryInput<T> &input,
612                                            T libcResult,
613                                            testutils::StreamWrapper &OS) {
614   MPFRNumber mpfrX(input.x, Precision<T>::value);
615   MPFRNumber mpfrY(input.y, Precision<T>::value);
616   MPFRNumber mpfrZ(input.z, Precision<T>::value);
617   FPBits<T> xbits(input.x);
618   FPBits<T> ybits(input.y);
619   FPBits<T> zbits(input.z);
620   MPFRNumber mpfrResult =
621       ternaryOperationOneOutput(op, input.x, input.y, input.z);
622   MPFRNumber mpfrMatchValue(libcResult);
623 
624   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
625      << " z: " << mpfrZ.str() << '\n';
626   __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
627                                               OS);
628   __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
629                                               OS);
630   __llvm_libc::fputil::testing::describeValue("Third input bits: ", input.z,
631                                               OS);
632 
633   OS << "Libc result: " << mpfrMatchValue.str() << '\n'
634      << "MPFR result: " << mpfrResult.str() << '\n';
635   __llvm_libc::fputil::testing::describeValue(
636       "Libc floating point result bits: ", libcResult, OS);
637   __llvm_libc::fputil::testing::describeValue(
638       "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
639   OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
640 }
641 
642 template void explainTernaryOperationOneOutputError<float>(
643     Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
644 template void explainTernaryOperationOneOutputError<double>(
645     Operation, const TernaryInput<double> &, double,
646     testutils::StreamWrapper &);
647 template void explainTernaryOperationOneOutputError<long double>(
648     Operation, const TernaryInput<long double> &, long double,
649     testutils::StreamWrapper &);
650 
651 template <typename T>
652 bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
653                                        double ulpError) {
654   // If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
655   // is rounded to the nearest even.
656   MPFRNumber mpfrResult = unaryOperation(op, input);
657   double ulp = mpfrResult.ulp(libcResult);
658   bool bitsAreEven = ((FPBits<T>(libcResult).uintval() & 1) == 0);
659   return (ulp < ulpError) ||
660          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
661 }
662 
663 template bool compareUnaryOperationSingleOutput<float>(Operation, float, float,
664                                                        double);
665 template bool compareUnaryOperationSingleOutput<double>(Operation, double,
666                                                         double, double);
667 template bool compareUnaryOperationSingleOutput<long double>(Operation,
668                                                              long double,
669                                                              long double,
670                                                              double);
671 
672 template <typename T>
673 bool compareUnaryOperationTwoOutputs(Operation op, T input,
674                                      const BinaryOutput<T> &libcResult,
675                                      double ulpError) {
676   int mpfrIntResult;
677   MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
678   double ulp = mpfrResult.ulp(libcResult.f);
679 
680   if (mpfrIntResult != libcResult.i)
681     return false;
682 
683   bool bitsAreEven = ((FPBits<T>(libcResult.f).uintval() & 1) == 0);
684   return (ulp < ulpError) ||
685          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
686 }
687 
688 template bool
689 compareUnaryOperationTwoOutputs<float>(Operation, float,
690                                        const BinaryOutput<float> &, double);
691 template bool
692 compareUnaryOperationTwoOutputs<double>(Operation, double,
693                                         const BinaryOutput<double> &, double);
694 template bool compareUnaryOperationTwoOutputs<long double>(
695     Operation, long double, const BinaryOutput<long double> &, double);
696 
697 template <typename T>
698 bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
699                                       const BinaryOutput<T> &libcResult,
700                                       double ulpError) {
701   int mpfrIntResult;
702   MPFRNumber mpfrResult =
703       binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
704   double ulp = mpfrResult.ulp(libcResult.f);
705 
706   if (mpfrIntResult != libcResult.i) {
707     if (op == Operation::RemQuo) {
708       if ((0x7 & mpfrIntResult) != (0x7 & libcResult.i))
709         return false;
710     } else {
711       return false;
712     }
713   }
714 
715   bool bitsAreEven = ((FPBits<T>(libcResult.f).uintval() & 1) == 0);
716   return (ulp < ulpError) ||
717          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
718 }
719 
720 template bool
721 compareBinaryOperationTwoOutputs<float>(Operation, const BinaryInput<float> &,
722                                         const BinaryOutput<float> &, double);
723 template bool
724 compareBinaryOperationTwoOutputs<double>(Operation, const BinaryInput<double> &,
725                                          const BinaryOutput<double> &, double);
726 template bool compareBinaryOperationTwoOutputs<long double>(
727     Operation, const BinaryInput<long double> &,
728     const BinaryOutput<long double> &, double);
729 
730 template <typename T>
731 bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
732                                      T libcResult, double ulpError) {
733   MPFRNumber mpfrResult = binaryOperationOneOutput(op, input.x, input.y);
734   double ulp = mpfrResult.ulp(libcResult);
735 
736   bool bitsAreEven = ((FPBits<T>(libcResult).uintval() & 1) == 0);
737   return (ulp < ulpError) ||
738          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
739 }
740 
741 template bool compareBinaryOperationOneOutput<float>(Operation,
742                                                      const BinaryInput<float> &,
743                                                      float, double);
744 template bool
745 compareBinaryOperationOneOutput<double>(Operation, const BinaryInput<double> &,
746                                         double, double);
747 template bool compareBinaryOperationOneOutput<long double>(
748     Operation, const BinaryInput<long double> &, long double, double);
749 
750 template <typename T>
751 bool compareTernaryOperationOneOutput(Operation op,
752                                       const TernaryInput<T> &input,
753                                       T libcResult, double ulpError) {
754   MPFRNumber mpfrResult =
755       ternaryOperationOneOutput(op, input.x, input.y, input.z);
756   double ulp = mpfrResult.ulp(libcResult);
757 
758   bool bitsAreEven = ((FPBits<T>(libcResult).uintval() & 1) == 0);
759   return (ulp < ulpError) ||
760          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
761 }
762 
763 template bool
764 compareTernaryOperationOneOutput<float>(Operation, const TernaryInput<float> &,
765                                         float, double);
766 template bool compareTernaryOperationOneOutput<double>(
767     Operation, const TernaryInput<double> &, double, double);
768 template bool compareTernaryOperationOneOutput<long double>(
769     Operation, const TernaryInput<long double> &, long double, double);
770 
771 static mpfr_rnd_t getMPFRRoundingMode(RoundingMode mode) {
772   switch (mode) {
773   case RoundingMode::Upward:
774     return MPFR_RNDU;
775     break;
776   case RoundingMode::Downward:
777     return MPFR_RNDD;
778     break;
779   case RoundingMode::TowardZero:
780     return MPFR_RNDZ;
781     break;
782   case RoundingMode::Nearest:
783     return MPFR_RNDN;
784     break;
785   }
786 }
787 
788 } // namespace internal
789 
790 template <typename T> bool RoundToLong(T x, long &result) {
791   MPFRNumber mpfr(x);
792   return mpfr.roundToLong(result);
793 }
794 
795 template bool RoundToLong<float>(float, long &);
796 template bool RoundToLong<double>(double, long &);
797 template bool RoundToLong<long double>(long double, long &);
798 
799 template <typename T> bool RoundToLong(T x, RoundingMode mode, long &result) {
800   MPFRNumber mpfr(x);
801   return mpfr.roundToLong(internal::getMPFRRoundingMode(mode), result);
802 }
803 
804 template bool RoundToLong<float>(float, RoundingMode, long &);
805 template bool RoundToLong<double>(double, RoundingMode, long &);
806 template bool RoundToLong<long double>(long double, RoundingMode, long &);
807 
808 template <typename T> T Round(T x, RoundingMode mode) {
809   MPFRNumber mpfr(x);
810   MPFRNumber result = mpfr.rint(internal::getMPFRRoundingMode(mode));
811   return result.as<T>();
812 }
813 
814 template float Round<float>(float, RoundingMode);
815 template double Round<double>(double, RoundingMode);
816 template long double Round<long double>(long double, RoundingMode);
817 
818 } // namespace mpfr
819 } // namespace testing
820 } // namespace __llvm_libc
821