xref: /llvm-project-15.0.7/libc/src/math/fmaf.cpp (revision dd5165a9)
1 //===-- Implementation of fmaf function -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "src/__support/common.h"
10 
11 #include "utils/FPUtil/FEnv.h"
12 #include "utils/FPUtil/FPBits.h"
13 
14 namespace __llvm_libc {
15 
16 float LLVM_LIBC_ENTRYPOINT(fmaf)(float x, float y, float z) {
17   // Product is exact.
18   double prod = static_cast<double>(x) * static_cast<double>(y);
19   double z_d = static_cast<double>(z);
20   double sum = prod + z_d;
21   fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
22 
23   if (!(bit_sum.isInfOrNaN() || bit_sum.isZero())) {
24     // Since the sum is computed in double precision, rounding might happen
25     // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
26     // bit_prod.exponent > bitz.exponent + 40).  In that case, when we round
27     // the sum back to float, double rounding error might occur.
28     // A concrete example of this phenomenon is as follows:
29     //   x = y = 1 + 2^(-12), z = 2^(-53)
30     // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
31     // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
32     // On the other hand, with the default rounding mode,
33     //   double(x*y + z) = 1 + 2^(-11) + 2^(-24)
34     // and casting again to float gives us:
35     //   float(double(x*y + z)) = 1 + 2^(-11).
36     //
37     // In order to correct this possible double rounding error, first we use
38     // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
39     // assuming the (default) rounding mode is round-to-the-nearest,
40     // tie-to-even.  Moreover, t satisfies the condition that t < eps(sum),
41     // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
42     // occurs when computing the sum, we just need to use t to adjust (any) last
43     // bit of sum, so that the sticky bits used when rounding sum to float are
44     // correct (when it matters).
45     fputil::FPBits<double> t(
46         (bit_prod.exponent >= bitz.exponent)
47             ? ((static_cast<double>(bit_sum) - bit_prod) - bitz)
48             : ((static_cast<double>(bit_sum) - bitz) - bit_prod));
49 
50     // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
51     // zero.
52     if (!t.isZero() && ((bit_sum.mantissa & 0xfff'ffffULL) == 0)) {
53       if (bit_sum.sign != t.sign) {
54         ++bit_sum.mantissa;
55       } else if (bit_sum.mantissa) {
56         --bit_sum.mantissa;
57       }
58     }
59   }
60 
61   return static_cast<float>(static_cast<double>(bit_sum));
62 }
63 
64 } // namespace __llvm_libc
65