1 //===--- Float16bits.cpp - supports 2-byte floats ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements f16 and bf16 to support the compilation and execution 10 // of programs using these types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/ExecutionEngine/Float16bits.h" 15 #include "llvm/Support/Compiler.h" 16 17 namespace { 18 19 // Union used to make the int/float aliasing explicit so we can access the raw 20 // bits. 21 union Float32Bits { 22 uint32_t u; 23 float f; 24 }; 25 26 const uint32_t kF32MantiBits = 23; 27 const uint32_t kF32HalfMantiBitDiff = 13; 28 const uint32_t kF32HalfBitDiff = 16; 29 const Float32Bits kF32Magic = {113 << kF32MantiBits}; 30 const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; 31 32 // Constructs the 16 bit representation for a half precision value from a float 33 // value. This implementation is adapted from Eigen. 34 uint16_t float2half(float floatValue) { 35 const Float32Bits inf = {255 << kF32MantiBits}; 36 const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; 37 const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) 38 << kF32MantiBits}; 39 uint32_t signMask = 0x80000000u; 40 uint16_t halfValue = static_cast<uint16_t>(0x0u); 41 Float32Bits f; 42 f.f = floatValue; 43 uint32_t sign = f.u & signMask; 44 f.u ^= sign; 45 46 if (f.u >= f16max.u) { 47 const uint32_t halfQnan = 0x7e00; 48 const uint32_t halfInf = 0x7c00; 49 // Inf or NaN (all exponent bits set). 50 halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf 51 } else { 52 // (De)normalized number or zero. 53 if (f.u < kF32Magic.u) { 54 // The resulting FP16 is subnormal or zero. 55 // 56 // Use a magic value to align our 10 mantissa bits at the bottom of the 57 // float. As long as FP addition is round-to-nearest-even this works. 58 f.f += denormMagic.f; 59 60 halfValue = static_cast<uint16_t>(f.u - denormMagic.u); 61 } else { 62 uint32_t mantOdd = 63 (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. 64 65 // Update exponent, rounding bias part 1. The following expressions are 66 // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + 67 // 0xfff`, but without arithmetic overflow. 68 f.u += 0xc8000fffU; 69 // Rounding bias part 2. 70 f.u += mantOdd; 71 halfValue = static_cast<uint16_t>(f.u >> kF32HalfMantiBitDiff); 72 } 73 } 74 75 halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff); 76 return halfValue; 77 } 78 79 // Converts the 16 bit representation of a half precision value to a float 80 // value. This implementation is adapted from Eigen. 81 float half2float(uint16_t halfValue) { 82 const uint32_t shiftedExp = 83 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. 84 85 // Initialize the float representation with the exponent/mantissa bits. 86 Float32Bits f = { 87 static_cast<uint32_t>((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; 88 const uint32_t exp = shiftedExp & f.u; 89 f.u += kF32HalfExpAdjust; // Adjust the exponent 90 91 // Handle exponent special cases. 92 if (exp == shiftedExp) { 93 // Inf/NaN 94 f.u += kF32HalfExpAdjust; 95 } else if (exp == 0) { 96 // Zero/Denormal? 97 f.u += 1 << kF32MantiBits; 98 f.f -= kF32Magic.f; 99 } 100 101 f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. 102 return f.f; 103 } 104 105 const uint32_t kF32BfMantiBitDiff = 16; 106 107 // Constructs the 16 bit representation for a bfloat value from a float value. 108 // This implementation is adapted from Eigen. 109 uint16_t float2bfloat(float floatValue) { 110 Float32Bits floatBits; 111 floatBits.f = floatValue; 112 uint16_t bfloatBits; 113 114 // Least significant bit of resulting bfloat. 115 uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; 116 uint32_t rounding_bias = 0x7fff + lsb; 117 floatBits.u += rounding_bias; 118 bfloatBits = static_cast<uint16_t>(floatBits.u >> kF32BfMantiBitDiff); 119 return bfloatBits; 120 } 121 122 // Converts the 16 bit representation of a bfloat value to a float value. This 123 // implementation is adapted from Eigen. 124 float bfloat2float(uint16_t bfloatBits) { 125 Float32Bits floatBits; 126 floatBits.u = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff; 127 return floatBits.f; 128 } 129 130 } // namespace 131 132 f16::f16(float f) : bits(float2half(f)) {} 133 134 bf16::bf16(float f) : bits(float2bfloat(f)) {} 135 136 std::ostream &operator<<(std::ostream &os, const f16 &f) { 137 os << half2float(f.bits); 138 return os; 139 } 140 141 std::ostream &operator<<(std::ostream &os, const bf16 &d) { 142 os << bfloat2float(d.bits); 143 return os; 144 } 145 146 // Provide a float->bfloat conversion routine in case the runtime doesn't have 147 // one. 148 extern "C" uint16_t LLVM_ATTRIBUTE_WEAK __truncsfbf2(float f) { 149 return float2bfloat(f); 150 } 151 152 // Provide a double->bfloat conversion routine in case the runtime doesn't have 153 // one. 154 extern "C" uint16_t LLVM_ATTRIBUTE_WEAK __truncdfbf2(double d) { 155 // This does a double rounding step, but it's precise enough for our use 156 // cases. 157 return __truncsfbf2(static_cast<float>(d)); 158 } 159