1 //===-- Utils which wrap MPFR ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MPFRUtils.h"
10 
11 #include "src/__support/CPP/StringView.h"
12 #include "src/__support/FPUtil/FPBits.h"
13 #include "src/__support/architectures.h"
14 #include "utils/UnitTest/FPMatcher.h"
15 
16 #include <cmath>
17 #include <memory>
18 #include <stdint.h>
19 #include <string>
20 
21 #ifdef CUSTOM_MPFR_INCLUDER
// Some downstream repos are monoliths carrying MPFR sources in their third
// party directory. In such repos, including the MPFR header as
// `#include <mpfr.h>` is either disallowed or not possible. If that is the
// case, a file named `CustomMPFRIncluder.h` should be added through which the
// MPFR header can be included in a manner allowed in that repo.
27 #include "CustomMPFRIncluder.h"
28 #else
29 #include <mpfr.h>
30 #endif
31 
// Shorthand so that code in this file can spell FPBits<T> instead of the fully
// qualified __llvm_libc::fputil::FPBits<T>.
template <typename T> using FPBits = __llvm_libc::fputil::FPBits<T>;
33 
34 namespace __llvm_libc {
35 namespace testing {
36 namespace mpfr {
37 
// Precision<T>::value is the number of significand bits (including the
// leading bit) of the floating point type T. It is used as the MPFR precision
// when an MPFR computation has to be carried out with exactly the precision
// of T.
template <typename T> struct Precision;

template <> struct Precision<float> {
  static constexpr unsigned int value = 24;
};

template <> struct Precision<double> {
  static constexpr unsigned int value = 53;
};

#if defined(LLVM_LIBC_ARCH_X86)
// On x86, long double is the 80-bit extended precision format, which has a
// 64-bit significand.
template <> struct Precision<long double> {
  static constexpr unsigned int value = 64;
};
#else
// Other architectures are treated as having a quad precision (binary128)
// long double with a 113-bit significand.
// NOTE(review): targets where long double has the same format as double are
// not handled here -- confirm whether any supported target needs that.
template <> struct Precision<long double> {
  static constexpr unsigned int value = 113;
};
#endif
57 
58 class MPFRNumber {
59   // A precision value which allows sufficiently large additional
60   // precision even compared to quad-precision floating point values.
61   unsigned int mpfrPrecision;
62 
63   mpfr_t value;
64 
65 public:
66   MPFRNumber() : mpfrPrecision(256) { mpfr_init2(value, mpfrPrecision); }
67 
68   // We use explicit EnableIf specializations to disallow implicit
69   // conversions. Implicit conversions can potentially lead to loss of
70   // precision.
71   template <typename XType,
72             cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
73   explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
74     mpfr_init2(value, mpfrPrecision);
75     mpfr_set_flt(value, x, MPFR_RNDN);
76   }
77 
78   template <typename XType,
79             cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
80   explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
81     mpfr_init2(value, mpfrPrecision);
82     mpfr_set_d(value, x, MPFR_RNDN);
83   }
84 
85   template <typename XType,
86             cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
87   explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
88     mpfr_init2(value, mpfrPrecision);
89     mpfr_set_ld(value, x, MPFR_RNDN);
90   }
91 
92   template <typename XType,
93             cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
94   explicit MPFRNumber(XType x, int precision = 128) : mpfrPrecision(precision) {
95     mpfr_init2(value, mpfrPrecision);
96     mpfr_set_sj(value, x, MPFR_RNDN);
97   }
98 
99   MPFRNumber(const MPFRNumber &other) : mpfrPrecision(other.mpfrPrecision) {
100     mpfr_init2(value, mpfrPrecision);
101     mpfr_set(value, other.value, MPFR_RNDN);
102   }
103 
104   ~MPFRNumber() { mpfr_clear(value); }
105 
106   MPFRNumber &operator=(const MPFRNumber &rhs) {
107     mpfrPrecision = rhs.mpfrPrecision;
108     mpfr_set(value, rhs.value, MPFR_RNDN);
109     return *this;
110   }
111 
112   MPFRNumber abs() const {
113     MPFRNumber result;
114     mpfr_abs(result.value, value, MPFR_RNDN);
115     return result;
116   }
117 
118   MPFRNumber ceil() const {
119     MPFRNumber result;
120     mpfr_ceil(result.value, value);
121     return result;
122   }
123 
124   MPFRNumber cos() const {
125     MPFRNumber result;
126     mpfr_cos(result.value, value, MPFR_RNDN);
127     return result;
128   }
129 
130   MPFRNumber exp() const {
131     MPFRNumber result;
132     mpfr_exp(result.value, value, MPFR_RNDN);
133     return result;
134   }
135 
136   MPFRNumber exp2() const {
137     MPFRNumber result;
138     mpfr_exp2(result.value, value, MPFR_RNDN);
139     return result;
140   }
141 
142   MPFRNumber expm1() const {
143     MPFRNumber result;
144     mpfr_expm1(result.value, value, MPFR_RNDN);
145     return result;
146   }
147 
148   MPFRNumber floor() const {
149     MPFRNumber result;
150     mpfr_floor(result.value, value);
151     return result;
152   }
153 
154   MPFRNumber frexp(int &exp) {
155     MPFRNumber result;
156     mpfr_exp_t resultExp;
157     mpfr_frexp(&resultExp, result.value, value, MPFR_RNDN);
158     exp = resultExp;
159     return result;
160   }
161 
162   MPFRNumber hypot(const MPFRNumber &b) {
163     MPFRNumber result;
164     mpfr_hypot(result.value, value, b.value, MPFR_RNDN);
165     return result;
166   }
167 
168   MPFRNumber log() const {
169     MPFRNumber result;
170     mpfr_log(result.value, value, MPFR_RNDN);
171     return result;
172   }
173 
174   MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
175     MPFRNumber remainder;
176     long q;
177     mpfr_remquo(remainder.value, &q, value, divisor.value, MPFR_RNDN);
178     quotient = q;
179     return remainder;
180   }
181 
182   MPFRNumber round() const {
183     MPFRNumber result;
184     mpfr_round(result.value, value);
185     return result;
186   }
187 
188   bool roundToLong(long &result) const {
189     // We first calculate the rounded value. This way, when converting
190     // to long using mpfr_get_si, the rounding direction of MPFR_RNDN
191     // (or any other rounding mode), does not have an influence.
192     MPFRNumber roundedValue = round();
193     mpfr_clear_erangeflag();
194     result = mpfr_get_si(roundedValue.value, MPFR_RNDN);
195     return mpfr_erangeflag_p();
196   }
197 
198   bool roundToLong(mpfr_rnd_t rnd, long &result) const {
199     MPFRNumber rint_result;
200     mpfr_rint(rint_result.value, value, rnd);
201     return rint_result.roundToLong(result);
202   }
203 
204   MPFRNumber rint(mpfr_rnd_t rnd) const {
205     MPFRNumber result;
206     mpfr_rint(result.value, value, rnd);
207     return result;
208   }
209 
210   MPFRNumber mod_2pi() const {
211     MPFRNumber result(0.0, 1280);
212     MPFRNumber _2pi(0.0, 1280);
213     mpfr_const_pi(_2pi.value, MPFR_RNDN);
214     mpfr_mul_si(_2pi.value, _2pi.value, 2, MPFR_RNDN);
215     mpfr_fmod(result.value, value, _2pi.value, MPFR_RNDN);
216     return result;
217   }
218 
219   MPFRNumber mod_pi_over_2() const {
220     MPFRNumber result(0.0, 1280);
221     MPFRNumber pi_over_2(0.0, 1280);
222     mpfr_const_pi(pi_over_2.value, MPFR_RNDN);
223     mpfr_mul_d(pi_over_2.value, pi_over_2.value, 0.5, MPFR_RNDN);
224     mpfr_fmod(result.value, value, pi_over_2.value, MPFR_RNDN);
225     return result;
226   }
227 
228   MPFRNumber mod_pi_over_4() const {
229     MPFRNumber result(0.0, 1280);
230     MPFRNumber pi_over_4(0.0, 1280);
231     mpfr_const_pi(pi_over_4.value, MPFR_RNDN);
232     mpfr_mul_d(pi_over_4.value, pi_over_4.value, 0.25, MPFR_RNDN);
233     mpfr_fmod(result.value, value, pi_over_4.value, MPFR_RNDN);
234     return result;
235   }
236 
237   MPFRNumber sin() const {
238     MPFRNumber result;
239     mpfr_sin(result.value, value, MPFR_RNDN);
240     return result;
241   }
242 
243   MPFRNumber sqrt() const {
244     MPFRNumber result;
245     mpfr_sqrt(result.value, value, MPFR_RNDN);
246     return result;
247   }
248 
249   MPFRNumber tan() const {
250     MPFRNumber result;
251     mpfr_tan(result.value, value, MPFR_RNDN);
252     return result;
253   }
254 
255   MPFRNumber trunc() const {
256     MPFRNumber result;
257     mpfr_trunc(result.value, value);
258     return result;
259   }
260 
261   MPFRNumber fma(const MPFRNumber &b, const MPFRNumber &c) {
262     MPFRNumber result(*this);
263     mpfr_fma(result.value, value, b.value, c.value, MPFR_RNDN);
264     return result;
265   }
266 
267   std::string str() const {
268     // 200 bytes should be more than sufficient to hold a 100-digit number
269     // plus additional bytes for the decimal point, '-' sign etc.
270     constexpr size_t printBufSize = 200;
271     char buffer[printBufSize];
272     mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value);
273     cpp::StringView view(buffer);
274     view = view.trim(' ');
275     return std::string(view.data());
276   }
277 
278   // These functions are useful for debugging.
279   template <typename T> T as() const;
280 
281   template <> float as<float>() const { return mpfr_get_flt(value, MPFR_RNDN); }
282   template <> double as<double>() const { return mpfr_get_d(value, MPFR_RNDN); }
283   template <> long double as<long double>() const {
284     return mpfr_get_ld(value, MPFR_RNDN);
285   }
286 
287   void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
288 
289   // Return the ULP (units-in-the-last-place) difference between the
290   // stored MPFR and a floating point number.
291   //
292   // We define ULP difference as follows:
293   //   If exponents of this value and the |input| are same, then:
294   //     ULP(this_value, input) = abs(this_value - input) / eps(input)
295   //   else:
296   //     max = max(abs(this_value), abs(input))
297   //     min = min(abs(this_value), abs(input))
298   //     maxExponent = exponent(max)
299   //     ULP(this_value, input) = (max - 2^maxExponent) / eps(max) +
300   //                              (2^maxExponent - min) / eps(min)
301   //
302   // Remarks:
303   // 1. A ULP of 0.0 will imply that the value is correctly rounded.
304   // 2. We expect that this value and the value to be compared (the [input]
305   //    argument) are reasonable close, and we will provide an upper bound
306   //    of ULP value for testing.  Morever, most of the fractional parts of
307   //    ULP value do not matter much, so using double as the return type
308   //    should be good enough.
309   // 3. For close enough values (values which don't diff in their exponent by
310   //    not more than 1), a ULP difference of N indicates a bit distance
311   //    of N between this number and [input].
312   // 4. A values of +0.0 and -0.0 are treated as equal.
313   template <typename T>
314   cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, double> ulp(T input) {
315     T thisAsT = as<T>();
316     if (thisAsT == input)
317       return T(0.0);
318 
319     int thisExponent = fputil::FPBits<T>(thisAsT).get_exponent();
320     int inputExponent = fputil::FPBits<T>(input).get_exponent();
321     // Adjust the exponents for denormal numbers.
322     if (fputil::FPBits<T>(thisAsT).get_unbiased_exponent() == 0)
323       ++thisExponent;
324     if (fputil::FPBits<T>(input).get_unbiased_exponent() == 0)
325       ++inputExponent;
326 
327     if (thisAsT * input < 0 || thisExponent == inputExponent) {
328       MPFRNumber inputMPFR(input);
329       mpfr_sub(inputMPFR.value, value, inputMPFR.value, MPFR_RNDN);
330       mpfr_abs(inputMPFR.value, inputMPFR.value, MPFR_RNDN);
331       mpfr_mul_2si(inputMPFR.value, inputMPFR.value,
332                    -thisExponent + int(fputil::MantissaWidth<T>::VALUE),
333                    MPFR_RNDN);
334       return inputMPFR.as<double>();
335     }
336 
337     // If the control reaches here, it means that this number and input are
338     // of the same sign but different exponent. In such a case, ULP error is
339     // calculated as sum of two parts.
340     thisAsT = std::abs(thisAsT);
341     input = std::abs(input);
342     T min = thisAsT > input ? input : thisAsT;
343     T max = thisAsT > input ? thisAsT : input;
344     int minExponent = fputil::FPBits<T>(min).get_exponent();
345     int maxExponent = fputil::FPBits<T>(max).get_exponent();
346     // Adjust the exponents for denormal numbers.
347     if (fputil::FPBits<T>(min).get_unbiased_exponent() == 0)
348       ++minExponent;
349     if (fputil::FPBits<T>(max).get_unbiased_exponent() == 0)
350       ++maxExponent;
351 
352     MPFRNumber minMPFR(min);
353     MPFRNumber maxMPFR(max);
354 
355     MPFRNumber pivot(uint32_t(1));
356     mpfr_mul_2si(pivot.value, pivot.value, maxExponent, MPFR_RNDN);
357 
358     mpfr_sub(minMPFR.value, pivot.value, minMPFR.value, MPFR_RNDN);
359     mpfr_mul_2si(minMPFR.value, minMPFR.value,
360                  -minExponent + int(fputil::MantissaWidth<T>::VALUE),
361                  MPFR_RNDN);
362 
363     mpfr_sub(maxMPFR.value, maxMPFR.value, pivot.value, MPFR_RNDN);
364     mpfr_mul_2si(maxMPFR.value, maxMPFR.value,
365                  -maxExponent + int(fputil::MantissaWidth<T>::VALUE),
366                  MPFR_RNDN);
367 
368     mpfr_add(minMPFR.value, minMPFR.value, maxMPFR.value, MPFR_RNDN);
369     return minMPFR.as<double>();
370   }
371 };
372 
373 namespace internal {
374 
375 template <typename InputType>
376 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
377 unaryOperation(Operation op, InputType input) {
378   MPFRNumber mpfrInput(input);
379   switch (op) {
380   case Operation::Abs:
381     return mpfrInput.abs();
382   case Operation::Ceil:
383     return mpfrInput.ceil();
384   case Operation::Cos:
385     return mpfrInput.cos();
386   case Operation::Exp:
387     return mpfrInput.exp();
388   case Operation::Exp2:
389     return mpfrInput.exp2();
390   case Operation::Expm1:
391     return mpfrInput.expm1();
392   case Operation::Floor:
393     return mpfrInput.floor();
394   case Operation::Log:
395     return mpfrInput.log();
396   case Operation::Mod2PI:
397     return mpfrInput.mod_2pi();
398   case Operation::ModPIOver2:
399     return mpfrInput.mod_pi_over_2();
400   case Operation::ModPIOver4:
401     return mpfrInput.mod_pi_over_4();
402   case Operation::Round:
403     return mpfrInput.round();
404   case Operation::Sin:
405     return mpfrInput.sin();
406   case Operation::Sqrt:
407     return mpfrInput.sqrt();
408   case Operation::Tan:
409     return mpfrInput.tan();
410   case Operation::Trunc:
411     return mpfrInput.trunc();
412   default:
413     __builtin_unreachable();
414   }
415 }
416 
417 template <typename InputType>
418 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
419 unaryOperationTwoOutputs(Operation op, InputType input, int &output) {
420   MPFRNumber mpfrInput(input);
421   switch (op) {
422   case Operation::Frexp:
423     return mpfrInput.frexp(output);
424   default:
425     __builtin_unreachable();
426   }
427 }
428 
429 template <typename InputType>
430 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
431 binaryOperationOneOutput(Operation op, InputType x, InputType y) {
432   MPFRNumber inputX(x), inputY(y);
433   switch (op) {
434   case Operation::Hypot:
435     return inputX.hypot(inputY);
436   default:
437     __builtin_unreachable();
438   }
439 }
440 
441 template <typename InputType>
442 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
443 binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
444   MPFRNumber inputX(x), inputY(y);
445   switch (op) {
446   case Operation::RemQuo:
447     return inputX.remquo(inputY, output);
448   default:
449     __builtin_unreachable();
450   }
451 }
452 
453 template <typename InputType>
454 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
455 ternaryOperationOneOutput(Operation op, InputType x, InputType y, InputType z) {
456   // For FMA function, we just need to compare with the mpfr_fma with the same
457   // precision as InputType.  Using higher precision as the intermediate results
458   // to compare might incorrectly fail due to double-rounding errors.
459   constexpr unsigned int prec = Precision<InputType>::value;
460   MPFRNumber inputX(x, prec), inputY(y, prec), inputZ(z, prec);
461   switch (op) {
462   case Operation::Fma:
463     return inputX.fma(inputY, inputZ);
464   default:
465     __builtin_unreachable();
466   }
467 }
468 
469 template <typename T>
470 void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
471                                             testutils::StreamWrapper &OS) {
472   MPFRNumber mpfrInput(input);
473   MPFRNumber mpfrResult = unaryOperation(op, input);
474   MPFRNumber mpfrMatchValue(matchValue);
475   FPBits<T> inputBits(input);
476   FPBits<T> matchBits(matchValue);
477   FPBits<T> mpfrResultBits(mpfrResult.as<T>());
478   OS << "Match value not within tolerance value of MPFR result:\n"
479      << "  Input decimal: " << mpfrInput.str() << '\n';
480   __llvm_libc::fputil::testing::describeValue("     Input bits: ", input, OS);
481   OS << '\n' << "  Match decimal: " << mpfrMatchValue.str() << '\n';
482   __llvm_libc::fputil::testing::describeValue("     Match bits: ", matchValue,
483                                               OS);
484   OS << '\n' << "    MPFR result: " << mpfrResult.str() << '\n';
485   __llvm_libc::fputil::testing::describeValue(
486       "   MPFR rounded: ", mpfrResult.as<T>(), OS);
487   OS << '\n';
488   OS << "      ULP error: " << std::to_string(mpfrResult.ulp(matchValue))
489      << '\n';
490 }
491 
492 template void
493 explainUnaryOperationSingleOutputError<float>(Operation op, float, float,
494                                               testutils::StreamWrapper &);
495 template void
496 explainUnaryOperationSingleOutputError<double>(Operation op, double, double,
497                                                testutils::StreamWrapper &);
498 template void explainUnaryOperationSingleOutputError<long double>(
499     Operation op, long double, long double, testutils::StreamWrapper &);
500 
501 template <typename T>
502 void explainUnaryOperationTwoOutputsError(Operation op, T input,
503                                           const BinaryOutput<T> &libcResult,
504                                           testutils::StreamWrapper &OS) {
505   MPFRNumber mpfrInput(input);
506   FPBits<T> inputBits(input);
507   int mpfrIntResult;
508   MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
509 
510   if (mpfrIntResult != libcResult.i) {
511     OS << "MPFR integral result: " << mpfrIntResult << '\n'
512        << "Libc integral result: " << libcResult.i << '\n';
513   } else {
514     OS << "Integral result from libc matches integral result from MPFR.\n";
515   }
516 
517   MPFRNumber mpfrMatchValue(libcResult.f);
518   OS << "Libc floating point result is not within tolerance value of the MPFR "
519      << "result.\n\n";
520 
521   OS << "            Input decimal: " << mpfrInput.str() << "\n\n";
522 
523   OS << "Libc floating point value: " << mpfrMatchValue.str() << '\n';
524   __llvm_libc::fputil::testing::describeValue(
525       " Libc floating point bits: ", libcResult.f, OS);
526   OS << "\n\n";
527 
528   OS << "              MPFR result: " << mpfrResult.str() << '\n';
529   __llvm_libc::fputil::testing::describeValue(
530       "             MPFR rounded: ", mpfrResult.as<T>(), OS);
531   OS << '\n'
532      << "                ULP error: "
533      << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
534 }
535 
536 template void explainUnaryOperationTwoOutputsError<float>(
537     Operation, float, const BinaryOutput<float> &, testutils::StreamWrapper &);
538 template void
539 explainUnaryOperationTwoOutputsError<double>(Operation, double,
540                                              const BinaryOutput<double> &,
541                                              testutils::StreamWrapper &);
542 template void explainUnaryOperationTwoOutputsError<long double>(
543     Operation, long double, const BinaryOutput<long double> &,
544     testutils::StreamWrapper &);
545 
546 template <typename T>
547 void explainBinaryOperationTwoOutputsError(Operation op,
548                                            const BinaryInput<T> &input,
549                                            const BinaryOutput<T> &libcResult,
550                                            testutils::StreamWrapper &OS) {
551   MPFRNumber mpfrX(input.x);
552   MPFRNumber mpfrY(input.y);
553   FPBits<T> xbits(input.x);
554   FPBits<T> ybits(input.y);
555   int mpfrIntResult;
556   MPFRNumber mpfrResult =
557       binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
558   MPFRNumber mpfrMatchValue(libcResult.f);
559 
560   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n'
561      << "MPFR integral result: " << mpfrIntResult << '\n'
562      << "Libc integral result: " << libcResult.i << '\n'
563      << "Libc floating point result: " << mpfrMatchValue.str() << '\n'
564      << "               MPFR result: " << mpfrResult.str() << '\n';
565   __llvm_libc::fputil::testing::describeValue(
566       "Libc floating point result bits: ", libcResult.f, OS);
567   __llvm_libc::fputil::testing::describeValue(
568       "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
569   OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult.f)) << '\n';
570 }
571 
572 template void explainBinaryOperationTwoOutputsError<float>(
573     Operation, const BinaryInput<float> &, const BinaryOutput<float> &,
574     testutils::StreamWrapper &);
575 template void explainBinaryOperationTwoOutputsError<double>(
576     Operation, const BinaryInput<double> &, const BinaryOutput<double> &,
577     testutils::StreamWrapper &);
578 template void explainBinaryOperationTwoOutputsError<long double>(
579     Operation, const BinaryInput<long double> &,
580     const BinaryOutput<long double> &, testutils::StreamWrapper &);
581 
582 template <typename T>
583 void explainBinaryOperationOneOutputError(Operation op,
584                                           const BinaryInput<T> &input,
585                                           T libcResult,
586                                           testutils::StreamWrapper &OS) {
587   MPFRNumber mpfrX(input.x);
588   MPFRNumber mpfrY(input.y);
589   FPBits<T> xbits(input.x);
590   FPBits<T> ybits(input.y);
591   MPFRNumber mpfrResult = binaryOperationOneOutput(op, input.x, input.y);
592   MPFRNumber mpfrMatchValue(libcResult);
593 
594   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
595   __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
596                                               OS);
597   __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
598                                               OS);
599 
600   OS << "Libc result: " << mpfrMatchValue.str() << '\n'
601      << "MPFR result: " << mpfrResult.str() << '\n';
602   __llvm_libc::fputil::testing::describeValue(
603       "Libc floating point result bits: ", libcResult, OS);
604   __llvm_libc::fputil::testing::describeValue(
605       "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
606   OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
607 }
608 
609 template void explainBinaryOperationOneOutputError<float>(
610     Operation, const BinaryInput<float> &, float, testutils::StreamWrapper &);
611 template void explainBinaryOperationOneOutputError<double>(
612     Operation, const BinaryInput<double> &, double, testutils::StreamWrapper &);
613 template void explainBinaryOperationOneOutputError<long double>(
614     Operation, const BinaryInput<long double> &, long double,
615     testutils::StreamWrapper &);
616 
617 template <typename T>
618 void explainTernaryOperationOneOutputError(Operation op,
619                                            const TernaryInput<T> &input,
620                                            T libcResult,
621                                            testutils::StreamWrapper &OS) {
622   MPFRNumber mpfrX(input.x, Precision<T>::value);
623   MPFRNumber mpfrY(input.y, Precision<T>::value);
624   MPFRNumber mpfrZ(input.z, Precision<T>::value);
625   FPBits<T> xbits(input.x);
626   FPBits<T> ybits(input.y);
627   FPBits<T> zbits(input.z);
628   MPFRNumber mpfrResult =
629       ternaryOperationOneOutput(op, input.x, input.y, input.z);
630   MPFRNumber mpfrMatchValue(libcResult);
631 
632   OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str()
633      << " z: " << mpfrZ.str() << '\n';
634   __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
635                                               OS);
636   __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
637                                               OS);
638   __llvm_libc::fputil::testing::describeValue("Third input bits: ", input.z,
639                                               OS);
640 
641   OS << "Libc result: " << mpfrMatchValue.str() << '\n'
642      << "MPFR result: " << mpfrResult.str() << '\n';
643   __llvm_libc::fputil::testing::describeValue(
644       "Libc floating point result bits: ", libcResult, OS);
645   __llvm_libc::fputil::testing::describeValue(
646       "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
647   OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
648 }
649 
650 template void explainTernaryOperationOneOutputError<float>(
651     Operation, const TernaryInput<float> &, float, testutils::StreamWrapper &);
652 template void explainTernaryOperationOneOutputError<double>(
653     Operation, const TernaryInput<double> &, double,
654     testutils::StreamWrapper &);
655 template void explainTernaryOperationOneOutputError<long double>(
656     Operation, const TernaryInput<long double> &, long double,
657     testutils::StreamWrapper &);
658 
659 template <typename T>
660 bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
661                                        double ulpError) {
662   // If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
663   // is rounded to the nearest even.
664   MPFRNumber mpfrResult = unaryOperation(op, input);
665   double ulp = mpfrResult.ulp(libcResult);
666   bool bitsAreEven = ((FPBits<T>(libcResult).uintval() & 1) == 0);
667   return (ulp < ulpError) ||
668          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
669 }
670 
671 template bool compareUnaryOperationSingleOutput<float>(Operation, float, float,
672                                                        double);
673 template bool compareUnaryOperationSingleOutput<double>(Operation, double,
674                                                         double, double);
675 template bool compareUnaryOperationSingleOutput<long double>(Operation,
676                                                              long double,
677                                                              long double,
678                                                              double);
679 
680 template <typename T>
681 bool compareUnaryOperationTwoOutputs(Operation op, T input,
682                                      const BinaryOutput<T> &libcResult,
683                                      double ulpError) {
684   int mpfrIntResult;
685   MPFRNumber mpfrResult = unaryOperationTwoOutputs(op, input, mpfrIntResult);
686   double ulp = mpfrResult.ulp(libcResult.f);
687 
688   if (mpfrIntResult != libcResult.i)
689     return false;
690 
691   bool bitsAreEven = ((FPBits<T>(libcResult.f).uintval() & 1) == 0);
692   return (ulp < ulpError) ||
693          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
694 }
695 
696 template bool
697 compareUnaryOperationTwoOutputs<float>(Operation, float,
698                                        const BinaryOutput<float> &, double);
699 template bool
700 compareUnaryOperationTwoOutputs<double>(Operation, double,
701                                         const BinaryOutput<double> &, double);
702 template bool compareUnaryOperationTwoOutputs<long double>(
703     Operation, long double, const BinaryOutput<long double> &, double);
704 
705 template <typename T>
706 bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input,
707                                       const BinaryOutput<T> &libcResult,
708                                       double ulpError) {
709   int mpfrIntResult;
710   MPFRNumber mpfrResult =
711       binaryOperationTwoOutputs(op, input.x, input.y, mpfrIntResult);
712   double ulp = mpfrResult.ulp(libcResult.f);
713 
714   if (mpfrIntResult != libcResult.i) {
715     if (op == Operation::RemQuo) {
716       if ((0x7 & mpfrIntResult) != (0x7 & libcResult.i))
717         return false;
718     } else {
719       return false;
720     }
721   }
722 
723   bool bitsAreEven = ((FPBits<T>(libcResult.f).uintval() & 1) == 0);
724   return (ulp < ulpError) ||
725          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
726 }
727 
728 template bool
729 compareBinaryOperationTwoOutputs<float>(Operation, const BinaryInput<float> &,
730                                         const BinaryOutput<float> &, double);
731 template bool
732 compareBinaryOperationTwoOutputs<double>(Operation, const BinaryInput<double> &,
733                                          const BinaryOutput<double> &, double);
734 template bool compareBinaryOperationTwoOutputs<long double>(
735     Operation, const BinaryInput<long double> &,
736     const BinaryOutput<long double> &, double);
737 
738 template <typename T>
739 bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
740                                      T libcResult, double ulpError) {
741   MPFRNumber mpfrResult = binaryOperationOneOutput(op, input.x, input.y);
742   double ulp = mpfrResult.ulp(libcResult);
743 
744   bool bitsAreEven = ((FPBits<T>(libcResult).uintval() & 1) == 0);
745   return (ulp < ulpError) ||
746          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
747 }
748 
749 template bool compareBinaryOperationOneOutput<float>(Operation,
750                                                      const BinaryInput<float> &,
751                                                      float, double);
752 template bool
753 compareBinaryOperationOneOutput<double>(Operation, const BinaryInput<double> &,
754                                         double, double);
755 template bool compareBinaryOperationOneOutput<long double>(
756     Operation, const BinaryInput<long double> &, long double, double);
757 
758 template <typename T>
759 bool compareTernaryOperationOneOutput(Operation op,
760                                       const TernaryInput<T> &input,
761                                       T libcResult, double ulpError) {
762   MPFRNumber mpfrResult =
763       ternaryOperationOneOutput(op, input.x, input.y, input.z);
764   double ulp = mpfrResult.ulp(libcResult);
765 
766   bool bitsAreEven = ((FPBits<T>(libcResult).uintval() & 1) == 0);
767   return (ulp < ulpError) ||
768          ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
769 }
770 
771 template bool
772 compareTernaryOperationOneOutput<float>(Operation, const TernaryInput<float> &,
773                                         float, double);
774 template bool compareTernaryOperationOneOutput<double>(
775     Operation, const TernaryInput<double> &, double, double);
776 template bool compareTernaryOperationOneOutput<long double>(
777     Operation, const TernaryInput<long double> &, long double, double);
778 
779 static mpfr_rnd_t getMPFRRoundingMode(RoundingMode mode) {
780   switch (mode) {
781   case RoundingMode::Upward:
782     return MPFR_RNDU;
783     break;
784   case RoundingMode::Downward:
785     return MPFR_RNDD;
786     break;
787   case RoundingMode::TowardZero:
788     return MPFR_RNDZ;
789     break;
790   case RoundingMode::Nearest:
791     return MPFR_RNDN;
792     break;
793   }
794 }
795 
796 } // namespace internal
797 
798 template <typename T> bool RoundToLong(T x, long &result) {
799   MPFRNumber mpfr(x);
800   return mpfr.roundToLong(result);
801 }
802 
803 template bool RoundToLong<float>(float, long &);
804 template bool RoundToLong<double>(double, long &);
805 template bool RoundToLong<long double>(long double, long &);
806 
807 template <typename T> bool RoundToLong(T x, RoundingMode mode, long &result) {
808   MPFRNumber mpfr(x);
809   return mpfr.roundToLong(internal::getMPFRRoundingMode(mode), result);
810 }
811 
812 template bool RoundToLong<float>(float, RoundingMode, long &);
813 template bool RoundToLong<double>(double, RoundingMode, long &);
814 template bool RoundToLong<long double>(long double, RoundingMode, long &);
815 
816 template <typename T> T Round(T x, RoundingMode mode) {
817   MPFRNumber mpfr(x);
818   MPFRNumber result = mpfr.rint(internal::getMPFRRoundingMode(mode));
819   return result.as<T>();
820 }
821 
822 template float Round<float>(float, RoundingMode);
823 template double Round<double>(double, RoundingMode);
824 template long double Round<long double>(long double, RoundingMode);
825 
826 } // namespace mpfr
827 } // namespace testing
828 } // namespace __llvm_libc
829