1 //===--- Float16bits.cpp - supports 2-byte floats ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements f16 and bf16 to support the compilation and execution 10 // of programs using these types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/Float16bits.h" 15 #include <cmath> 16 #include <cstring> 17 18 namespace { 19 20 // Union used to make the int/float aliasing explicit so we can access the raw 21 // bits. 22 union Float32Bits { 23 uint32_t u; 24 float f; 25 }; 26 27 const uint32_t kF32MantiBits = 23; 28 const uint32_t kF32HalfMantiBitDiff = 13; 29 const uint32_t kF32HalfBitDiff = 16; 30 const Float32Bits kF32Magic = {113 << kF32MantiBits}; 31 const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; 32 33 // Constructs the 16 bit representation for a half precision value from a float 34 // value. This implementation is adapted from Eigen. 35 uint16_t float2half(float floatValue) { 36 const Float32Bits inf = {255 << kF32MantiBits}; 37 const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; 38 const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) 39 << kF32MantiBits}; 40 uint32_t signMask = 0x80000000u; 41 uint16_t halfValue = static_cast<uint16_t>(0x0u); 42 Float32Bits f; 43 f.f = floatValue; 44 uint32_t sign = f.u & signMask; 45 f.u ^= sign; 46 47 if (f.u >= f16max.u) { 48 const uint32_t halfQnan = 0x7e00; 49 const uint32_t halfInf = 0x7c00; 50 // Inf or NaN (all exponent bits set). 51 halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf 52 } else { 53 // (De)normalized number or zero. 54 if (f.u < kF32Magic.u) { 55 // The resulting FP16 is subnormal or zero. 56 // 57 // Use a magic value to align our 10 mantissa bits at the bottom of the 58 // float. As long as FP addition is round-to-nearest-even this works. 59 f.f += denormMagic.f; 60 61 halfValue = static_cast<uint16_t>(f.u - denormMagic.u); 62 } else { 63 uint32_t mantOdd = 64 (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. 65 66 // Update exponent, rounding bias part 1. The following expressions are 67 // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + 68 // 0xfff`, but without arithmetic overflow. 69 f.u += 0xc8000fffU; 70 // Rounding bias part 2. 71 f.u += mantOdd; 72 halfValue = static_cast<uint16_t>(f.u >> kF32HalfMantiBitDiff); 73 } 74 } 75 76 halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff); 77 return halfValue; 78 } 79 80 // Converts the 16 bit representation of a half precision value to a float 81 // value. This implementation is adapted from Eigen. 82 float half2float(uint16_t halfValue) { 83 const uint32_t shiftedExp = 84 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. 85 86 // Initialize the float representation with the exponent/mantissa bits. 87 Float32Bits f = { 88 static_cast<uint32_t>((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; 89 const uint32_t exp = shiftedExp & f.u; 90 f.u += kF32HalfExpAdjust; // Adjust the exponent 91 92 // Handle exponent special cases. 93 if (exp == shiftedExp) { 94 // Inf/NaN 95 f.u += kF32HalfExpAdjust; 96 } else if (exp == 0) { 97 // Zero/Denormal? 98 f.u += 1 << kF32MantiBits; 99 f.f -= kF32Magic.f; 100 } 101 102 f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. 103 return f.f; 104 } 105 106 const uint32_t kF32BfMantiBitDiff = 16; 107 108 // Constructs the 16 bit representation for a bfloat value from a float value. 109 // This implementation is adapted from Eigen. 110 uint16_t float2bfloat(float floatValue) { 111 if (std::isnan(floatValue)) 112 return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; 113 114 Float32Bits floatBits; 115 floatBits.f = floatValue; 116 uint16_t bfloatBits; 117 118 // Least significant bit of resulting bfloat. 119 uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; 120 uint32_t roundingBias = 0x7fff + lsb; 121 floatBits.u += roundingBias; 122 bfloatBits = static_cast<uint16_t>(floatBits.u >> kF32BfMantiBitDiff); 123 return bfloatBits; 124 } 125 126 // Converts the 16 bit representation of a bfloat value to a float value. This 127 // implementation is adapted from Eigen. 128 float bfloat2float(uint16_t bfloatBits) { 129 Float32Bits floatBits; 130 floatBits.u = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff; 131 return floatBits.f; 132 } 133 134 } // namespace 135 136 f16::f16(float f) : bits(float2half(f)) {} 137 138 bf16::bf16(float f) : bits(float2bfloat(f)) {} 139 140 std::ostream &operator<<(std::ostream &os, const f16 &f) { 141 os << half2float(f.bits); 142 return os; 143 } 144 145 std::ostream &operator<<(std::ostream &os, const bf16 &d) { 146 os << bfloat2float(d.bits); 147 return os; 148 } 149 150 // Mark these symbols as weak so they don't conflict when compiler-rt also 151 // defines them. 152 #define ATTR_WEAK 153 #ifdef __has_attribute 154 #if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \ 155 !defined(_WIN32) 156 #undef ATTR_WEAK 157 #define ATTR_WEAK __attribute__((__weak__)) 158 #endif 159 #endif 160 161 #if defined(__x86_64__) 162 // On x86 bfloat16 is passed in SSE2 registers. Since both float and _Float16 163 // are passed in the same register we can use the wider type and careful casting 164 // to conform to x86_64 psABI. This only works with the assumption that we're 165 // dealing with little-endian values passed in wider registers. 166 using BF16ABIType = float; 167 #else 168 // Default to uint16_t if we have nothing else. 169 using BF16ABIType = uint16_t; 170 #endif 171 172 // Provide a float->bfloat conversion routine in case the runtime doesn't have 173 // one. 174 extern "C" BF16ABIType ATTR_WEAK __truncsfbf2(float f) { 175 uint16_t bf = float2bfloat(f); 176 // The output can be a float type, bitcast it from uint16_t. 177 BF16ABIType ret = 0; 178 std::memcpy(&ret, &bf, sizeof(bf)); 179 return ret; 180 } 181 182 // Provide a double->bfloat conversion routine in case the runtime doesn't have 183 // one. 184 extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) { 185 // This does a double rounding step, but it's precise enough for our use 186 // cases. 187 uint16_t bf = __truncsfbf2(static_cast<float>(d)); 188 // The output can be a float type, bitcast it from uint16_t. 189 BF16ABIType ret = 0; 190 std::memcpy(&ret, &bf, sizeof(bf)); 191 return ret; 192 } 193