1 //===--- Float16bits.cpp - supports 2-byte floats ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements f16 and bf16 to support the compilation and execution 10 // of programs using these types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/Float16bits.h" 15 #include <cmath> 16 17 namespace { 18 19 // Union used to make the int/float aliasing explicit so we can access the raw 20 // bits. 21 union Float32Bits { 22 uint32_t u; 23 float f; 24 }; 25 26 const uint32_t kF32MantiBits = 23; 27 const uint32_t kF32HalfMantiBitDiff = 13; 28 const uint32_t kF32HalfBitDiff = 16; 29 const Float32Bits kF32Magic = {113 << kF32MantiBits}; 30 const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; 31 32 // Constructs the 16 bit representation for a half precision value from a float 33 // value. This implementation is adapted from Eigen. 34 uint16_t float2half(float floatValue) { 35 const Float32Bits inf = {255 << kF32MantiBits}; 36 const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; 37 const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) 38 << kF32MantiBits}; 39 uint32_t signMask = 0x80000000u; 40 uint16_t halfValue = static_cast<uint16_t>(0x0u); 41 Float32Bits f; 42 f.f = floatValue; 43 uint32_t sign = f.u & signMask; 44 f.u ^= sign; 45 46 if (f.u >= f16max.u) { 47 const uint32_t halfQnan = 0x7e00; 48 const uint32_t halfInf = 0x7c00; 49 // Inf or NaN (all exponent bits set). 50 halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf 51 } else { 52 // (De)normalized number or zero. 53 if (f.u < kF32Magic.u) { 54 // The resulting FP16 is subnormal or zero. 55 // 56 // Use a magic value to align our 10 mantissa bits at the bottom of the 57 // float. As long as FP addition is round-to-nearest-even this works. 58 f.f += denormMagic.f; 59 60 halfValue = static_cast<uint16_t>(f.u - denormMagic.u); 61 } else { 62 uint32_t mantOdd = 63 (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. 64 65 // Update exponent, rounding bias part 1. The following expressions are 66 // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + 67 // 0xfff`, but without arithmetic overflow. 68 f.u += 0xc8000fffU; 69 // Rounding bias part 2. 70 f.u += mantOdd; 71 halfValue = static_cast<uint16_t>(f.u >> kF32HalfMantiBitDiff); 72 } 73 } 74 75 halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff); 76 return halfValue; 77 } 78 79 // Converts the 16 bit representation of a half precision value to a float 80 // value. This implementation is adapted from Eigen. 81 float half2float(uint16_t halfValue) { 82 const uint32_t shiftedExp = 83 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. 84 85 // Initialize the float representation with the exponent/mantissa bits. 86 Float32Bits f = { 87 static_cast<uint32_t>((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; 88 const uint32_t exp = shiftedExp & f.u; 89 f.u += kF32HalfExpAdjust; // Adjust the exponent 90 91 // Handle exponent special cases. 92 if (exp == shiftedExp) { 93 // Inf/NaN 94 f.u += kF32HalfExpAdjust; 95 } else if (exp == 0) { 96 // Zero/Denormal? 97 f.u += 1 << kF32MantiBits; 98 f.f -= kF32Magic.f; 99 } 100 101 f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. 102 return f.f; 103 } 104 105 const uint32_t kF32BfMantiBitDiff = 16; 106 107 // Constructs the 16 bit representation for a bfloat value from a float value. 108 // This implementation is adapted from Eigen. 109 uint16_t float2bfloat(float floatValue) { 110 if (std::isnan(floatValue)) 111 return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; 112 113 Float32Bits floatBits; 114 floatBits.f = floatValue; 115 uint16_t bfloatBits; 116 117 // Least significant bit of resulting bfloat. 118 uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; 119 uint32_t roundingBias = 0x7fff + lsb; 120 floatBits.u += roundingBias; 121 bfloatBits = static_cast<uint16_t>(floatBits.u >> kF32BfMantiBitDiff); 122 return bfloatBits; 123 } 124 125 // Converts the 16 bit representation of a bfloat value to a float value. This 126 // implementation is adapted from Eigen. 127 float bfloat2float(uint16_t bfloatBits) { 128 Float32Bits floatBits; 129 floatBits.u = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff; 130 return floatBits.f; 131 } 132 133 } // namespace 134 135 f16::f16(float f) : bits(float2half(f)) {} 136 137 bf16::bf16(float f) : bits(float2bfloat(f)) {} 138 139 std::ostream &operator<<(std::ostream &os, const f16 &f) { 140 os << half2float(f.bits); 141 return os; 142 } 143 144 std::ostream &operator<<(std::ostream &os, const bf16 &d) { 145 os << bfloat2float(d.bits); 146 return os; 147 } 148 149 // Provide a float->bfloat conversion routine in case the runtime doesn't have 150 // one. 151 extern "C" uint16_t 152 #ifdef __has_attribute 153 #if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \ 154 !defined(_WIN32) 155 __attribute__((__weak__)) 156 #endif 157 #endif 158 __truncsfbf2(float f) { 159 return float2bfloat(f); 160 } 161 162 // Provide a double->bfloat conversion routine in case the runtime doesn't have 163 // one. 164 extern "C" uint16_t 165 #ifdef __has_attribute 166 #if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \ 167 !defined(_WIN32) 168 __attribute__((__weak__)) 169 #endif 170 #endif 171 __truncdfbf2(double d) { 172 // This does a double rounding step, but it's precise enough for our use 173 // cases. 174 return __truncsfbf2(static_cast<float>(d)); 175 } 176