1ea8ed5cbSbixia1 //===--- Float16bits.cpp - supports 2-byte floats ------------------------===//
2ea8ed5cbSbixia1 //
3ea8ed5cbSbixia1 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ea8ed5cbSbixia1 // See https://llvm.org/LICENSE.txt for license information.
5ea8ed5cbSbixia1 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ea8ed5cbSbixia1 //
7ea8ed5cbSbixia1 //===----------------------------------------------------------------------===//
8ea8ed5cbSbixia1 //
9ea8ed5cbSbixia1 // This file implements f16 and bf16 to support the compilation and execution
10ea8ed5cbSbixia1 // of programs using these types.
11ea8ed5cbSbixia1 //
12ea8ed5cbSbixia1 //===----------------------------------------------------------------------===//
13ea8ed5cbSbixia1
14ea8ed5cbSbixia1 #include "mlir/ExecutionEngine/Float16bits.h"
15b3127769SBenjamin Kramer #include <cmath>
1623637ca0SBenjamin Kramer #include <cstring>
17ea8ed5cbSbixia1
namespace {

// The conversions below operate directly on IEEE-754 binary32 bit patterns.
static_assert(sizeof(float) == sizeof(uint32_t),
              "float must be a 32-bit type for the bit-level conversions");

// Returns the raw bit pattern of `value`. std::memcpy is the portable way to
// reinterpret an object's bytes in C++: reading a union member other than the
// one most recently written (the idiom used by the original Eigen code) is
// undefined behavior in standard C++ and only works as a compiler extension.
uint32_t floatToBits(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

// Returns the float whose raw bit pattern is `bits`.
float bitsToFloat(uint32_t bits) {
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

// Number of explicit mantissa bits in an f32.
constexpr uint32_t kF32MantiBits = 23;
// Mantissa width difference between f32 (23 bits) and f16 (10 bits).
constexpr uint32_t kF32HalfMantiBitDiff = 13;
// Distance from the f32 sign bit (bit 31) to the f16 sign bit (bit 15).
constexpr uint32_t kF32HalfBitDiff = 16;
// Bits of 2^-14 (biased f32 exponent 113), the smallest normal f16 magnitude.
constexpr uint32_t kF32MagicBits = 113u << kF32MantiBits;
// Re-biases an exponent field from the f16 bias (15) to the f32 bias (127).
constexpr uint32_t kF32HalfExpAdjust = (127u - 15u) << kF32MantiBits;

// Constructs the 16 bit representation for a half precision value from a float
// value, rounding to nearest-even. This implementation is adapted from Eigen.
//
// NaN inputs become the quiet NaN 0x7e00 (plus the input's sign). Magnitudes
// that round above 65504 become infinity. Tiny magnitudes become f16
// subnormals or zero.
uint16_t float2half(float floatValue) {
  constexpr uint32_t kSignMask = 0x80000000u;
  // Bits of +infinity; any larger magnitude pattern is a NaN.
  constexpr uint32_t kInfBits = 255u << kF32MantiBits;
  // Bits of 2^16; every magnitude at or above it is Inf/NaN in f16.
  constexpr uint32_t kF16MaxBits = (127u + 16u) << kF32MantiBits;
  // Bits of 0.5f, whose ulp is 2^-24 (the smallest f16 subnormal). Adding it
  // to a tiny magnitude leaves the f16 subnormal mantissa in the low f32 bits.
  constexpr uint32_t kDenormMagicBits =
      ((127u - 15u) + (kF32MantiBits - 10u) + 1u) << kF32MantiBits;
  constexpr uint16_t kHalfQnan = 0x7e00;
  constexpr uint16_t kHalfInf = 0x7c00;

  uint32_t bits = floatToBits(floatValue);
  const uint32_t sign = bits & kSignMask;
  bits ^= sign; // Work on the magnitude; the sign is re-applied at the end.

  uint16_t halfValue;
  if (bits >= kF16MaxBits) {
    // Inf or NaN (all exponent bits set): NaN->qNaN and Inf->Inf.
    halfValue = (bits > kInfBits) ? kHalfQnan : kHalfInf;
  } else if (bits < kF32MagicBits) {
    // The resulting FP16 is subnormal or zero. The float addition performs the
    // rounding, so this relies on the FP environment rounding to nearest-even.
    const float aligned = bitsToFloat(bits) + bitsToFloat(kDenormMagicBits);
    halfValue = static_cast<uint16_t>(floatToBits(aligned) - kDenormMagicBits);
  } else {
    // Normal range. mantOdd is 1 when the truncated f16 mantissa is odd; adding
    // it to the 0xfff bias makes exact ties round to even.
    const uint32_t mantOdd = (bits >> kF32HalfMantiBitDiff) & 1u;
    // 0xc8000fff == ((uint32_t)(15 - 127) << kF32MantiBits) + 0xfff: it
    // re-biases the exponent and adds the rounding bias in one step (unsigned
    // wraparound is well defined). A carry out of the mantissa bumps the
    // exponent; for inputs that round above 65504 this yields 0x7c00 (Inf).
    bits += 0xc8000fffu + mantOdd;
    halfValue = static_cast<uint16_t>(bits >> kF32HalfMantiBitDiff);
  }

  halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff);
  return halfValue;
}

// Converts the 16 bit representation of a half precision value to a float
// value. Every f16 value is exactly representable as an f32, so this is exact.
// This implementation is adapted from Eigen.
float half2float(uint16_t halfValue) {
  // f16 exponent mask, moved to where it lands after the shift below.
  constexpr uint32_t kShiftedExp = 0x7c00u << kF32HalfMantiBitDiff;

  // Move the f16 exponent/mantissa bits into f32 position (sign handled last).
  uint32_t bits = static_cast<uint32_t>(halfValue & 0x7fffu)
                  << kF32HalfMantiBitDiff;
  const uint32_t exp = bits & kShiftedExp;
  bits += kF32HalfExpAdjust; // Re-bias the exponent from 15 to 127.

  if (exp == kShiftedExp) {
    // Inf/NaN: push the exponent the rest of the way to all ones (255).
    bits += kF32HalfExpAdjust;
  } else if (exp == 0) {
    // Zero/subnormal: build 2^-14 * (1 + m/1024) and subtract 2^-14, which
    // leaves exactly m * 2^-24 as a float.
    bits += 1u << kF32MantiBits;
    bits = floatToBits(bitsToFloat(bits) - bitsToFloat(kF32MagicBits));
  }

  // Sign bit, computed in unsigned arithmetic (shifting a signed int into
  // bit 31 is implementation-defined before C++20).
  bits |= static_cast<uint32_t>(halfValue & 0x8000u) << kF32HalfBitDiff;
  return bitsToFloat(bits);
}

// A bf16 is the upper 16 bits of an f32 (sign, 8 exponent bits, 7 mantissa
// bits); this is the shift between the two layouts.
constexpr uint32_t kF32BfMantiBitDiff = 16;

// Constructs the 16 bit representation for a bfloat value from a float value,
// rounding to nearest-even. NaNs map to a quiet NaN that keeps the input's
// sign. This implementation is adapted from Eigen.
uint16_t float2bfloat(float floatValue) {
  // Handled up front: the rounding arithmetic below can turn some NaN
  // patterns into infinity.
  if (std::isnan(floatValue))
    return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0;

  uint32_t bits = floatToBits(floatValue);
  // Least significant bit of resulting bfloat; adding it makes ties go to even.
  const uint32_t lsb = (bits >> kF32BfMantiBitDiff) & 1u;
  bits += 0x7fffu + lsb;
  return static_cast<uint16_t>(bits >> kF32BfMantiBitDiff);
}

// Converts the 16 bit representation of a bfloat value to a float value. This
// is exact: the bf16 bits become the upper half of the f32 pattern. This
// implementation is adapted from Eigen.
float bfloat2float(uint16_t bfloatBits) {
  return bitsToFloat(static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff);
}

} // namespace
135ea8ed5cbSbixia1
f16(float f)136ea8ed5cbSbixia1 f16::f16(float f) : bits(float2half(f)) {}
137ea8ed5cbSbixia1
bf16(float f)138ea8ed5cbSbixia1 bf16::bf16(float f) : bits(float2bfloat(f)) {}
139ea8ed5cbSbixia1
operator <<(std::ostream & os,const f16 & f)140ea8ed5cbSbixia1 std::ostream &operator<<(std::ostream &os, const f16 &f) {
141ea8ed5cbSbixia1 os << half2float(f.bits);
142ea8ed5cbSbixia1 return os;
143ea8ed5cbSbixia1 }
144ea8ed5cbSbixia1
operator <<(std::ostream & os,const bf16 & d)145ea8ed5cbSbixia1 std::ostream &operator<<(std::ostream &os, const bf16 &d) {
146ea8ed5cbSbixia1 os << bfloat2float(d.bits);
147ea8ed5cbSbixia1 return os;
148ea8ed5cbSbixia1 }
1493420cd7cSBenjamin Kramer
// Mark these symbols as weak so they don't conflict when compiler-rt also
// defines them.
//
// ATTR_WEAK expands to nothing by default. It becomes
// __attribute__((__weak__)) only when all of the following hold:
//   - the compiler provides __has_attribute;
//   - it reports support for the `weak` attribute;
//   - the target is not MinGW, Cygwin, or any _WIN32 target (those are
//     excluded here).
// The #if nesting matters: __has_attribute(weak) may only be evaluated after
// __has_attribute itself is known to exist.
#define ATTR_WEAK
#ifdef __has_attribute
#if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \
    !defined(_WIN32)
#undef ATTR_WEAK
#define ATTR_WEAK __attribute__((__weak__))
#endif
#endif
16023637ca0SBenjamin Kramer
// Type used to return a bfloat16 value from the C-ABI conversion routines
// below.
//
// x86_64: the psABI passes bfloat16 in SSE registers, the same registers used
// for float. Returning a float therefore puts the bits where callers expect
// them, provided the 16-bit value sits in the low (little-endian) part of the
// wider register. Using __bf16 directly would be cleaner, but not every
// compiler supports that type.
//
// Other targets: with nothing better available, return the raw bits as
// uint16_t.
#if !defined(__x86_64__)
using BF16ABIType = uint16_t;
#else
using BF16ABIType = float;
#endif
17323637ca0SBenjamin Kramer
17423637ca0SBenjamin Kramer // Provide a float->bfloat conversion routine in case the runtime doesn't have
17523637ca0SBenjamin Kramer // one.
__truncsfbf2(float f)17623637ca0SBenjamin Kramer extern "C" BF16ABIType ATTR_WEAK __truncsfbf2(float f) {
17723637ca0SBenjamin Kramer uint16_t bf = float2bfloat(f);
17823637ca0SBenjamin Kramer // The output can be a float type, bitcast it from uint16_t.
17923637ca0SBenjamin Kramer BF16ABIType ret = 0;
18023637ca0SBenjamin Kramer std::memcpy(&ret, &bf, sizeof(bf));
18123637ca0SBenjamin Kramer return ret;
1823420cd7cSBenjamin Kramer }
1833420cd7cSBenjamin Kramer
1843420cd7cSBenjamin Kramer // Provide a double->bfloat conversion routine in case the runtime doesn't have
1853420cd7cSBenjamin Kramer // one.
__truncdfbf2(double d)18623637ca0SBenjamin Kramer extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) {
1873420cd7cSBenjamin Kramer // This does a double rounding step, but it's precise enough for our use
1883420cd7cSBenjamin Kramer // cases.
189*fbd2950dSBenjamin Kramer return __truncsfbf2(static_cast<float>(d));
1903420cd7cSBenjamin Kramer }
191