1 //===-- A class to manipulate wide integers. --------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_LIBC_UTILS_CPP_UINT_H 10 #define LLVM_LIBC_UTILS_CPP_UINT_H 11 12 #include "Array.h" 13 14 #include <stddef.h> // For size_t 15 #include <stdint.h> 16 17 namespace __llvm_libc { 18 namespace cpp { 19 20 template <size_t Bits> class UInt { 21 22 static_assert(Bits > 0 && Bits % 64 == 0, 23 "Number of bits in UInt should be a multiple of 64."); 24 static constexpr size_t WordCount = Bits / 64; 25 uint64_t val[WordCount]; 26 27 static constexpr uint64_t MASK32 = 0xFFFFFFFFu; 28 29 static constexpr uint64_t low(uint64_t v) { return v & MASK32; } 30 static constexpr uint64_t high(uint64_t v) { return (v >> 32) & MASK32; } 31 32 public: 33 constexpr UInt() {} 34 35 constexpr UInt(const UInt<Bits> &other) { 36 for (size_t i = 0; i < WordCount; ++i) 37 val[i] = other.val[i]; 38 } 39 40 // Initialize the first word to |v| and the rest to 0. 41 constexpr UInt(uint64_t v) { 42 val[0] = v; 43 for (size_t i = 1; i < WordCount; ++i) { 44 val[i] = 0; 45 } 46 } 47 constexpr explicit UInt(const cpp::Array<uint64_t, WordCount> &words) { 48 for (size_t i = 0; i < WordCount; ++i) 49 val[i] = words[i]; 50 } 51 52 constexpr explicit operator uint64_t() const { return val[0]; } 53 54 constexpr explicit operator uint32_t() const { 55 return uint32_t(uint64_t(*this)); 56 } 57 58 constexpr explicit operator uint8_t() const { 59 return uint8_t(uint64_t(*this)); 60 } 61 62 UInt<Bits> &operator=(const UInt<Bits> &other) { 63 for (size_t i = 0; i < WordCount; ++i) 64 val[i] = other.val[i]; 65 return *this; 66 } 67 68 // Add x to this number and store the result in this number. 69 // Returns the carry value produced by the addition operation. 70 constexpr uint64_t add(const UInt<Bits> &x) { 71 uint64_t carry = 0; 72 for (size_t i = 0; i < WordCount; ++i) { 73 uint64_t res_lo = low(val[i]) + low(x.val[i]) + carry; 74 carry = high(res_lo); 75 res_lo = low(res_lo); 76 77 uint64_t res_hi = high(val[i]) + high(x.val[i]) + carry; 78 carry = high(res_hi); 79 res_hi = low(res_hi); 80 81 val[i] = res_lo + (res_hi << 32); 82 } 83 return carry; 84 } 85 86 constexpr UInt<Bits> operator+(const UInt<Bits> &other) const { 87 UInt<Bits> result(*this); 88 result.add(other); 89 return result; 90 } 91 92 // Multiply this number with x and store the result in this number. It is 93 // implemented using the long multiplication algorithm by splitting the 94 // 64-bit words of this number and |x| in to 32-bit halves but peforming 95 // the operations using 64-bit numbers. This ensures that we don't lose the 96 // carry bits. 97 // Returns the carry value produced by the multiplication operation. 98 constexpr uint64_t mul(uint64_t x) { 99 uint64_t x_lo = low(x); 100 uint64_t x_hi = high(x); 101 102 cpp::Array<uint64_t, WordCount + 1> row1; 103 uint64_t carry = 0; 104 for (size_t i = 0; i < WordCount; ++i) { 105 uint64_t l = low(val[i]); 106 uint64_t h = high(val[i]); 107 uint64_t p1 = x_lo * l; 108 uint64_t p2 = x_lo * h; 109 110 uint64_t res_lo = low(p1) + carry; 111 carry = high(res_lo); 112 uint64_t res_hi = high(p1) + low(p2) + carry; 113 carry = high(res_hi) + high(p2); 114 115 res_lo = low(res_lo); 116 res_hi = low(res_hi); 117 row1[i] = res_lo + (res_hi << 32); 118 } 119 row1[WordCount] = carry; 120 121 cpp::Array<uint64_t, WordCount + 1> row2; 122 row2[0] = 0; 123 carry = 0; 124 for (size_t i = 0; i < WordCount; ++i) { 125 uint64_t l = low(val[i]); 126 uint64_t h = high(val[i]); 127 uint64_t p1 = x_hi * l; 128 uint64_t p2 = x_hi * h; 129 130 uint64_t res_lo = low(p1) + carry; 131 carry = high(res_lo); 132 uint64_t res_hi = high(p1) + low(p2) + carry; 133 carry = high(res_hi) + high(p2); 134 135 res_lo = low(res_lo); 136 res_hi = low(res_hi); 137 row2[i] = res_lo + (res_hi << 32); 138 } 139 row2[WordCount] = carry; 140 141 UInt<(WordCount + 1) * 64> r1(row1), r2(row2); 142 r2.shift_left(32); 143 r1.add(r2); 144 for (size_t i = 0; i < WordCount; ++i) { 145 val[i] = r1[i]; 146 } 147 return r1[WordCount]; 148 } 149 150 constexpr UInt<Bits> operator*(const UInt<Bits> &other) const { 151 UInt<Bits> result(0); 152 for (size_t i = 0; i < WordCount; ++i) { 153 UInt<Bits> row_result(*this); 154 row_result.mul(other[i]); 155 row_result.shift_left(64 * i); 156 result = result + row_result; 157 } 158 return result; 159 } 160 161 constexpr void shift_left(size_t s) { 162 const size_t drop = s / 64; // Number of words to drop 163 const size_t shift = s % 64; // Bits to shift in the remaining words. 164 const uint64_t mask = ((uint64_t(1) << shift) - 1) << (64 - shift); 165 166 for (size_t i = WordCount; drop > 0 && i > 0; --i) { 167 if (i > drop) 168 val[i - 1] = val[i - drop - 1]; 169 else 170 val[i - 1] = 0; 171 } 172 for (size_t i = WordCount; shift > 0 && i > drop; --i) { 173 uint64_t drop_val = (val[i - 1] & mask) >> (64 - shift); 174 val[i - 1] <<= shift; 175 if (i < WordCount) 176 val[i] |= drop_val; 177 } 178 } 179 180 constexpr UInt<Bits> operator<<(size_t s) const { 181 UInt<Bits> result(*this); 182 result.shift_left(s); 183 return result; 184 } 185 186 constexpr UInt<Bits> &operator<<=(size_t s) { 187 shift_left(s); 188 return *this; 189 } 190 191 constexpr void shift_right(size_t s) { 192 const size_t drop = s / 64; // Number of words to drop 193 const size_t shift = s % 64; // Bit shift in the remaining words. 194 const uint64_t mask = (uint64_t(1) << shift) - 1; 195 196 for (size_t i = 0; drop > 0 && i < WordCount; ++i) { 197 if (i + drop < WordCount) 198 val[i] = val[i + drop]; 199 else 200 val[i] = 0; 201 } 202 for (size_t i = 0; shift > 0 && i < WordCount; ++i) { 203 uint64_t drop_val = ((val[i] & mask) << (64 - shift)); 204 val[i] >>= shift; 205 if (i > 0) 206 val[i - 1] |= drop_val; 207 } 208 } 209 210 constexpr UInt<Bits> operator>>(size_t s) const { 211 UInt<Bits> result(*this); 212 result.shift_right(s); 213 return result; 214 } 215 216 constexpr UInt<Bits> &operator>>=(size_t s) { 217 shift_right(s); 218 return *this; 219 } 220 221 constexpr UInt<Bits> operator&(const UInt<Bits> &other) const { 222 UInt<Bits> result; 223 for (size_t i = 0; i < WordCount; ++i) 224 result.val[i] = val[i] & other.val[i]; 225 return result; 226 } 227 228 constexpr UInt<Bits> operator|(const UInt<Bits> &other) const { 229 UInt<Bits> result; 230 for (size_t i = 0; i < WordCount; ++i) 231 result.val[i] = val[i] | other.val[i]; 232 return result; 233 } 234 235 constexpr UInt<Bits> operator^(const UInt<Bits> &other) const { 236 UInt<Bits> result; 237 for (size_t i = 0; i < WordCount; ++i) 238 result.val[i] = val[i] ^ other.val[i]; 239 return result; 240 } 241 242 constexpr UInt<Bits> operator~() const { 243 UInt<Bits> result; 244 for (size_t i = 0; i < WordCount; ++i) 245 result.val[i] = ~val[i]; 246 return result; 247 } 248 249 constexpr bool operator==(const UInt<Bits> &other) const { 250 for (size_t i = 0; i < WordCount; ++i) { 251 if (val[i] != other.val[i]) 252 return false; 253 } 254 return true; 255 } 256 257 constexpr bool operator!=(const UInt<Bits> &other) const { 258 for (size_t i = 0; i < WordCount; ++i) { 259 if (val[i] != other.val[i]) 260 return true; 261 } 262 return false; 263 } 264 265 constexpr bool operator>(const UInt<Bits> &other) const { 266 for (size_t i = WordCount; i > 0; --i) { 267 uint64_t word = val[i - 1]; 268 uint64_t other_word = other.val[i - 1]; 269 if (word > other_word) 270 return true; 271 else if (word < other_word) 272 return false; 273 } 274 // Equal 275 return false; 276 } 277 278 constexpr bool operator>=(const UInt<Bits> &other) const { 279 for (size_t i = WordCount; i > 0; --i) { 280 uint64_t word = val[i - 1]; 281 uint64_t other_word = other.val[i - 1]; 282 if (word > other_word) 283 return true; 284 else if (word < other_word) 285 return false; 286 } 287 // Equal 288 return true; 289 } 290 291 constexpr bool operator<(const UInt<Bits> &other) const { 292 for (size_t i = WordCount; i > 0; --i) { 293 uint64_t word = val[i - 1]; 294 uint64_t other_word = other.val[i - 1]; 295 if (word > other_word) 296 return false; 297 else if (word < other_word) 298 return true; 299 } 300 // Equal 301 return false; 302 } 303 304 constexpr bool operator<=(const UInt<Bits> &other) const { 305 for (size_t i = WordCount; i > 0; --i) { 306 uint64_t word = val[i - 1]; 307 uint64_t other_word = other.val[i - 1]; 308 if (word > other_word) 309 return false; 310 else if (word < other_word) 311 return true; 312 } 313 // Equal 314 return true; 315 } 316 317 // Return the i-th 64-bit word of the number. 318 constexpr const uint64_t &operator[](size_t i) const { return val[i]; } 319 320 // Return the i-th 64-bit word of the number. 321 constexpr uint64_t &operator[](size_t i) { return val[i]; } 322 323 uint64_t *data() { return val; } 324 325 const uint64_t *data() const { return val; } 326 }; 327 328 template <> 329 constexpr UInt<128> UInt<128>::operator*(const UInt<128> &other) const { 330 // temp low covers bits 0-63, middle covers 32-95, high covers 64-127, and 331 // high overflow covers 96-159. 332 uint64_t temp_low = low(val[0]) * low(other[0]); 333 uint64_t temp_middle_1 = low(val[0]) * high(other[0]); 334 uint64_t temp_middle_2 = high(val[0]) * low(other[0]); 335 336 // temp_middle is split out so that overflows can be handled, but since 337 // but since the result will be truncated to 128 bits any overflow from here 338 // on doesn't matter. 339 uint64_t temp_high = low(val[0]) * low(other[1]) + 340 high(val[0]) * high(other[0]) + 341 low(val[1]) * low(other[0]); 342 343 uint64_t temp_high_overflow = 344 low(val[0]) * high(other[1]) + high(val[0]) * low(other[1]) + 345 low(val[1]) * high(other[0]) + high(val[1]) * low(other[0]); 346 347 // temp_low_middle has just the high 32 bits of low, as well as any 348 // overflow. 349 uint64_t temp_low_middle = 350 high(temp_low) + low(temp_middle_1) + low(temp_middle_2); 351 352 uint64_t new_low = low(temp_low) + (low(temp_low_middle) << 32); 353 uint64_t new_high = high(temp_low_middle) + high(temp_middle_1) + 354 high(temp_middle_2) + temp_high + 355 (low(temp_high_overflow) << 32); 356 UInt<128> result(0); 357 result[0] = new_low; 358 result[1] = new_high; 359 return result; 360 } 361 362 } // namespace cpp 363 } // namespace __llvm_libc 364 365 #endif // LLVM_LIBC_UTILS_CPP_UINT_H 366