1 //===--- Float16bits.cpp - supports 2-byte floats ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements f16 and bf16 to support the compilation and execution 10 // of programs using these types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/Float16bits.h" 15 16 namespace { 17 18 // Union used to make the int/float aliasing explicit so we can access the raw 19 // bits. 20 union Float32Bits { 21 uint32_t u; 22 float f; 23 }; 24 25 const uint32_t kF32MantiBits = 23; 26 const uint32_t kF32HalfMantiBitDiff = 13; 27 const uint32_t kF32HalfBitDiff = 16; 28 const Float32Bits kF32Magic = {113 << kF32MantiBits}; 29 const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; 30 31 // Constructs the 16 bit representation for a half precision value from a float 32 // value. This implementation is adapted from Eigen. 33 uint16_t float2half(float floatValue) { 34 const Float32Bits inf = {255 << kF32MantiBits}; 35 const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; 36 const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) 37 << kF32MantiBits}; 38 uint32_t signMask = 0x80000000u; 39 uint16_t halfValue = static_cast<uint16_t>(0x0u); 40 Float32Bits f; 41 f.f = floatValue; 42 uint32_t sign = f.u & signMask; 43 f.u ^= sign; 44 45 if (f.u >= f16max.u) { 46 const uint32_t halfQnan = 0x7e00; 47 const uint32_t halfInf = 0x7c00; 48 // Inf or NaN (all exponent bits set). 49 halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf 50 } else { 51 // (De)normalized number or zero. 52 if (f.u < kF32Magic.u) { 53 // The resulting FP16 is subnormal or zero. 54 // 55 // Use a magic value to align our 10 mantissa bits at the bottom of the 56 // float. As long as FP addition is round-to-nearest-even this works. 57 f.f += denormMagic.f; 58 59 halfValue = static_cast<uint16_t>(f.u - denormMagic.u); 60 } else { 61 uint32_t mantOdd = 62 (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. 63 64 // Update exponent, rounding bias part 1. The following expressions are 65 // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + 66 // 0xfff`, but without arithmetic overflow. 67 f.u += 0xc8000fffU; 68 // Rounding bias part 2. 69 f.u += mantOdd; 70 halfValue = static_cast<uint16_t>(f.u >> kF32HalfMantiBitDiff); 71 } 72 } 73 74 halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff); 75 return halfValue; 76 } 77 78 // Converts the 16 bit representation of a half precision value to a float 79 // value. This implementation is adapted from Eigen. 80 float half2float(uint16_t halfValue) { 81 const uint32_t shiftedExp = 82 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. 83 84 // Initialize the float representation with the exponent/mantissa bits. 85 Float32Bits f = { 86 static_cast<uint32_t>((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; 87 const uint32_t exp = shiftedExp & f.u; 88 f.u += kF32HalfExpAdjust; // Adjust the exponent 89 90 // Handle exponent special cases. 91 if (exp == shiftedExp) { 92 // Inf/NaN 93 f.u += kF32HalfExpAdjust; 94 } else if (exp == 0) { 95 // Zero/Denormal? 96 f.u += 1 << kF32MantiBits; 97 f.f -= kF32Magic.f; 98 } 99 100 f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. 101 return f.f; 102 } 103 104 const uint32_t kF32BfMantiBitDiff = 16; 105 106 // Constructs the 16 bit representation for a bfloat value from a float value. 107 // This implementation is adapted from Eigen. 108 uint16_t float2bfloat(float floatValue) { 109 Float32Bits floatBits; 110 floatBits.f = floatValue; 111 uint16_t bfloatBits; 112 113 // Least significant bit of resulting bfloat. 114 uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; 115 uint32_t rounding_bias = 0x7fff + lsb; 116 floatBits.u += rounding_bias; 117 bfloatBits = static_cast<uint16_t>(floatBits.u >> kF32BfMantiBitDiff); 118 return bfloatBits; 119 } 120 121 // Converts the 16 bit representation of a bfloat value to a float value. This 122 // implementation is adapted from Eigen. 123 float bfloat2float(uint16_t bfloatBits) { 124 Float32Bits floatBits; 125 floatBits.u = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff; 126 return floatBits.f; 127 } 128 129 } // namespace 130 131 f16::f16(float f) : bits(float2half(f)) {} 132 133 bf16::bf16(float f) : bits(float2bfloat(f)) {} 134 135 std::ostream &operator<<(std::ostream &os, const f16 &f) { 136 os << half2float(f.bits); 137 return os; 138 } 139 140 std::ostream &operator<<(std::ostream &os, const bf16 &d) { 141 os << bfloat2float(d.bits); 142 return os; 143 } 144 145 // Provide a float->bfloat conversion routine in case the runtime doesn't have 146 // one. 147 extern "C" uint16_t 148 #ifdef __has_attribute 149 #if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \ 150 !defined(_WIN32) 151 __attribute__((__weak__)) 152 #endif 153 #endif 154 __truncsfbf2(float f) { 155 return float2bfloat(f); 156 } 157 158 // Provide a double->bfloat conversion routine in case the runtime doesn't have 159 // one. 160 extern "C" uint16_t 161 #ifdef __has_attribute 162 #if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \ 163 !defined(_WIN32) 164 __attribute__((__weak__)) 165 #endif 166 #endif 167 __truncdfbf2(double d) { 168 // This does a double rounding step, but it's precise enough for our use 169 // cases. 170 return __truncsfbf2(static_cast<float>(d)); 171 } 172