1 //===--- Float16bits.cpp - supports 2-byte floats  ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements f16 and bf16 to support the compilation and execution
10 // of programs using these types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/ExecutionEngine/Float16bits.h"
15 
namespace {

// Union used to make the int/float aliasing explicit so we can access the raw
// bits. Reading the member that was not most recently written relies on the
// union type-punning extension that GCC and Clang support.
union Float32Bits {
  uint32_t u;
  float f;
};

// Number of explicit mantissa bits in an IEEE-754 binary32 float.
const uint32_t kF32MantiBits = 23;
// Mantissa width difference between float (23 bits) and half (10 bits).
const uint32_t kF32HalfMantiBitDiff = 13;
// Storage width difference between float (32 bits) and half (16 bits); this is
// also the distance between the two formats' sign bits.
const uint32_t kF32HalfBitDiff = 16;
// 2^-14 (biased float exponent 113), the smallest normal half value.
const Float32Bits kF32Magic = {113 << kF32MantiBits};
// Difference between the float (127) and half (15) exponent biases, already
// shifted into the float exponent field.
const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits;

// Constructs the 16 bit representation for a half precision value from a float
// value, rounding to nearest with ties to even. Finite values that round past
// the largest half (65504) become infinity. Every NaN becomes the quiet NaN
// 0x7e00 carrying the input's sign (the payload is dropped). This
// implementation is adapted from Eigen.
uint16_t float2half(float floatValue) {
  const Float32Bits inf = {255 << kF32MantiBits};
  // 2^16: any magnitude at or above this (including Inf/NaN) is Inf/NaN in
  // half.
  const Float32Bits f16max = {(127 + 16) << kF32MantiBits};
  // 0.5 (biased exponent 126). Its float ulp is 2^-24, the half subnormal step,
  // so adding it to a tiny value rounds that value onto the subnormal grid.
  const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1)
                                   << kF32MantiBits};
  uint32_t signMask = 0x80000000u;
  uint16_t halfValue = static_cast<uint16_t>(0x0u);
  Float32Bits f;
  f.f = floatValue;
  // Strip the sign so the magnitude can be compared as an unsigned integer.
  uint32_t sign = f.u & signMask;
  f.u ^= sign;

  if (f.u >= f16max.u) {
    const uint32_t halfQnan = 0x7e00;
    const uint32_t halfInf = 0x7c00;
    // Inf or NaN (all exponent bits set).
    halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf
  } else {
    // (De)normalized number or zero.
    if (f.u < kF32Magic.u) {
      // The resulting FP16 is subnormal or zero.
      //
      // Use a magic value to align our 10 mantissa bits at the bottom of the
      // float. As long as FP addition is round-to-nearest-even this works.
      f.f += denormMagic.f;

      // One integer subtraction of the magic's bits leaves the half encoding.
      halfValue = static_cast<uint16_t>(f.u - denormMagic.u);
    } else {
      uint32_t mantOdd =
          (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd.

      // Update exponent, rounding bias part 1. The following expressions are
      // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) +
      // 0xfff`, but without arithmetic overflow.
      f.u += 0xc8000fffU;
      // Rounding bias part 2: ties round towards an even half mantissa.
      f.u += mantOdd;
      halfValue = static_cast<uint16_t>(f.u >> kF32HalfMantiBitDiff);
    }
  }

  halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff);
  return halfValue;
}

// Converts the 16 bit representation of a half precision value to a float
// value. The conversion is exact. Inf stays Inf, and NaN payload bits are
// shifted into the float mantissa. This implementation is adapted from Eigen.
float half2float(uint16_t halfValue) {
  const uint32_t shiftedExp =
      0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift.

  // Initialize the float representation with the exponent/mantissa bits.
  Float32Bits f = {
      static_cast<uint32_t>((halfValue & 0x7fff) << kF32HalfMantiBitDiff)};
  const uint32_t exp = shiftedExp & f.u;
  f.u += kF32HalfExpAdjust; // Rebias the exponent from half to float.

  // Handle exponent special cases.
  if (exp == shiftedExp) {
    // Inf/NaN: a second adjustment takes the exponent from 31 + 112 = 143 to
    // 255, the all-ones float exponent.
    f.u += kF32HalfExpAdjust;
  } else if (exp == 0) {
    // Zero/subnormal: build 2^-14 * (1 + m/1024), then subtract 2^-14. This
    // leaves the exact value m * 2^-24 (or zero) as a normalized float.
    f.u += 1 << kF32MantiBits;
    f.f -= kF32Magic.f;
  }

  // Sign bit. Shift in unsigned arithmetic so that moving bit 15 into bit 31
  // never pushes a signed int into its sign bit.
  f.u |= static_cast<uint32_t>(halfValue & 0x8000u) << kF32HalfBitDiff;
  return f.f;
}

// Mantissa width difference between float (23 bits) and bfloat (7 bits). A
// bfloat is exactly the upper 16 bits of a float.
const uint32_t kF32BfMantiBitDiff = 16;

// Constructs the 16 bit representation for a bfloat value from a float value,
// rounding to nearest with ties to even. NaNs are mapped to a quiet NaN
// (0x7fc0, or 0xffc0 when the sign bit is set), as in Eigen. Without that check
// the rounding bias could carry a NaN into an infinity (e.g. 0x7f800001 ->
// 0x7f80) or wrap it around to zero (0xffffffff -> 0x0000).
// This implementation is adapted from Eigen.
uint16_t float2bfloat(float floatValue) {
  Float32Bits floatBits;
  floatBits.f = floatValue;

  // NaN: all exponent bits set and a nonzero mantissa, i.e. the magnitude bits
  // compare greater than those of infinity.
  const uint32_t signBit = floatBits.u & 0x80000000u;
  if ((floatBits.u ^ signBit) > 0x7f800000u)
    return static_cast<uint16_t>(signBit ? 0xffc0u : 0x7fc0u);

  // Least significant bit of resulting bfloat. Adding 0x7fff plus this bit and
  // truncating rounds to nearest, breaking ties towards an even result.
  const uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1;
  const uint32_t roundingBias = 0x7fff + lsb;
  floatBits.u += roundingBias;
  return static_cast<uint16_t>(floatBits.u >> kF32BfMantiBitDiff);
}

// Converts the 16 bit representation of a bfloat value to a float value. The
// conversion is exact: the bfloat bits become the upper half of the float.
// This implementation is adapted from Eigen.
float bfloat2float(uint16_t bfloatBits) {
  Float32Bits floatBits;
  floatBits.u = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff;
  return floatBits.f;
}

} // namespace
130 
131 f16::f16(float f) : bits(float2half(f)) {}
132 
133 bf16::bf16(float f) : bits(float2bfloat(f)) {}
134 
135 std::ostream &operator<<(std::ostream &os, const f16 &f) {
136   os << half2float(f.bits);
137   return os;
138 }
139 
140 std::ostream &operator<<(std::ostream &os, const bf16 &d) {
141   os << bfloat2float(d.bits);
142   return os;
143 }
144