1 //===-- Implementation of fmaf function -----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "src/__support/common.h" 10 11 #include "utils/FPUtil/FEnv.h" 12 #include "utils/FPUtil/FPBits.h" 13 14 namespace __llvm_libc { 15 16 float LLVM_LIBC_ENTRYPOINT(fmaf)(float x, float y, float z) { 17 // Product is exact. 18 double prod = static_cast<double>(x) * static_cast<double>(y); 19 double z_d = static_cast<double>(z); 20 double sum = prod + z_d; 21 fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum); 22 23 if (!(bit_sum.isInfOrNaN() || bit_sum.isZero())) { 24 // Since the sum is computed in double precision, rounding might happen 25 // (for instance, when bitz.exponent > bit_prod.exponent + 5, or 26 // bit_prod.exponent > bitz.exponent + 40). In that case, when we round 27 // the sum back to float, double rounding error might occur. 28 // A concrete example of this phenomenon is as follows: 29 // x = y = 1 + 2^(-12), z = 2^(-53) 30 // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53) 31 // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23) 32 // On the other hand, with the default rounding mode, 33 // double(x*y + z) = 1 + 2^(-11) + 2^(-24) 34 // and casting again to float gives us: 35 // float(double(x*y + z)) = 1 + 2^(-11). 36 // 37 // In order to correct this possible double rounding error, first we use 38 // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly, 39 // assuming the (default) rounding mode is round-to-the-nearest, 40 // tie-to-even. Moreover, t satisfies the condition that t < eps(sum), 41 // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding 42 // occurs when computing the sum, we just need to use t to adjust (any) last 43 // bit of sum, so that the sticky bits used when rounding sum to float are 44 // correct (when it matters). 45 fputil::FPBits<double> t( 46 (bit_prod.exponent >= bitz.exponent) 47 ? ((static_cast<double>(bit_sum) - bit_prod) - bitz) 48 : ((static_cast<double>(bit_sum) - bitz) - bit_prod)); 49 50 // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are 51 // zero. 52 if (!t.isZero() && ((bit_sum.mantissa & 0xfff'ffffULL) == 0)) { 53 if (bit_sum.sign != t.sign) { 54 ++bit_sum.mantissa; 55 } else if (bit_sum.mantissa) { 56 --bit_sum.mantissa; 57 } 58 } 59 } 60 61 return static_cast<float>(static_cast<double>(bit_sum)); 62 } 63 64 } // namespace __llvm_libc 65