1 //===-- A class to manipulate wide integers. --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_CPP_UINT_H
10 #define LLVM_LIBC_UTILS_CPP_UINT_H
11 
12 #include "Array.h"
13 
14 #include <stddef.h> // For size_t
15 #include <stdint.h>
16 
17 namespace __llvm_libc {
18 namespace cpp {
19 
20 template <size_t Bits> class UInt {
21 
22   static_assert(Bits > 0 && Bits % 64 == 0,
23                 "Number of bits in UInt should be a multiple of 64.");
24   static constexpr size_t WordCount = Bits / 64;
25   uint64_t val[WordCount];
26 
27   static constexpr uint64_t MASK32 = 0xFFFFFFFFu;
28 
low(uint64_t v)29   static constexpr uint64_t low(uint64_t v) { return v & MASK32; }
high(uint64_t v)30   static constexpr uint64_t high(uint64_t v) { return (v >> 32) & MASK32; }
31 
32 public:
UInt()33   constexpr UInt() {}
34 
UInt(const UInt<Bits> & other)35   constexpr UInt(const UInt<Bits> &other) {
36     for (size_t i = 0; i < WordCount; ++i)
37       val[i] = other.val[i];
38   }
39 
40   // Initialize the first word to |v| and the rest to 0.
UInt(uint64_t v)41   constexpr UInt(uint64_t v) {
42     val[0] = v;
43     for (size_t i = 1; i < WordCount; ++i) {
44       val[i] = 0;
45     }
46   }
UInt(const cpp::Array<uint64_t,WordCount> & words)47   constexpr explicit UInt(const cpp::Array<uint64_t, WordCount> &words) {
48     for (size_t i = 0; i < WordCount; ++i)
49       val[i] = words[i];
50   }
51 
uint64_t()52   constexpr explicit operator uint64_t() const { return val[0]; }
53 
uint32_t()54   constexpr explicit operator uint32_t() const {
55     return uint32_t(uint64_t(*this));
56   }
57 
uint8_t()58   constexpr explicit operator uint8_t() const {
59     return uint8_t(uint64_t(*this));
60   }
61 
62   UInt<Bits> &operator=(const UInt<Bits> &other) {
63     for (size_t i = 0; i < WordCount; ++i)
64       val[i] = other.val[i];
65     return *this;
66   }
67 
68   // Add x to this number and store the result in this number.
69   // Returns the carry value produced by the addition operation.
add(const UInt<Bits> & x)70   constexpr uint64_t add(const UInt<Bits> &x) {
71     uint64_t carry = 0;
72     for (size_t i = 0; i < WordCount; ++i) {
73       uint64_t res_lo = low(val[i]) + low(x.val[i]) + carry;
74       carry = high(res_lo);
75       res_lo = low(res_lo);
76 
77       uint64_t res_hi = high(val[i]) + high(x.val[i]) + carry;
78       carry = high(res_hi);
79       res_hi = low(res_hi);
80 
81       val[i] = res_lo + (res_hi << 32);
82     }
83     return carry;
84   }
85 
86   constexpr UInt<Bits> operator+(const UInt<Bits> &other) const {
87     UInt<Bits> result(*this);
88     result.add(other);
89     return result;
90   }
91 
92   constexpr UInt<Bits> operator+=(const UInt<Bits> &other) {
93     *this = *this + other;
94     return *this;
95   }
96 
97   // Multiply this number with x and store the result in this number. It is
98   // implemented using the long multiplication algorithm by splitting the
99   // 64-bit words of this number and |x| in to 32-bit halves but peforming
100   // the operations using 64-bit numbers. This ensures that we don't lose the
101   // carry bits.
102   // Returns the carry value produced by the multiplication operation.
mul(uint64_t x)103   constexpr uint64_t mul(uint64_t x) {
104     uint64_t x_lo = low(x);
105     uint64_t x_hi = high(x);
106 
107     cpp::Array<uint64_t, WordCount + 1> row1;
108     uint64_t carry = 0;
109     for (size_t i = 0; i < WordCount; ++i) {
110       uint64_t l = low(val[i]);
111       uint64_t h = high(val[i]);
112       uint64_t p1 = x_lo * l;
113       uint64_t p2 = x_lo * h;
114 
115       uint64_t res_lo = low(p1) + carry;
116       carry = high(res_lo);
117       uint64_t res_hi = high(p1) + low(p2) + carry;
118       carry = high(res_hi) + high(p2);
119 
120       res_lo = low(res_lo);
121       res_hi = low(res_hi);
122       row1[i] = res_lo + (res_hi << 32);
123     }
124     row1[WordCount] = carry;
125 
126     cpp::Array<uint64_t, WordCount + 1> row2;
127     row2[0] = 0;
128     carry = 0;
129     for (size_t i = 0; i < WordCount; ++i) {
130       uint64_t l = low(val[i]);
131       uint64_t h = high(val[i]);
132       uint64_t p1 = x_hi * l;
133       uint64_t p2 = x_hi * h;
134 
135       uint64_t res_lo = low(p1) + carry;
136       carry = high(res_lo);
137       uint64_t res_hi = high(p1) + low(p2) + carry;
138       carry = high(res_hi) + high(p2);
139 
140       res_lo = low(res_lo);
141       res_hi = low(res_hi);
142       row2[i] = res_lo + (res_hi << 32);
143     }
144     row2[WordCount] = carry;
145 
146     UInt<(WordCount + 1) * 64> r1(row1), r2(row2);
147     r2.shift_left(32);
148     r1.add(r2);
149     for (size_t i = 0; i < WordCount; ++i) {
150       val[i] = r1[i];
151     }
152     return r1[WordCount];
153   }
154 
155   constexpr UInt<Bits> operator*(const UInt<Bits> &other) const {
156     UInt<Bits> result(0);
157     for (size_t i = 0; i < WordCount; ++i) {
158       UInt<Bits> row_result(*this);
159       row_result.mul(other[i]);
160       row_result.shift_left(64 * i);
161       result = result + row_result;
162     }
163     return result;
164   }
165 
166   constexpr UInt<Bits> &operator*=(const UInt<Bits> &other) {
167     *this = *this * other;
168     return *this;
169   }
170 
shift_left(size_t s)171   constexpr void shift_left(size_t s) {
172     const size_t drop = s / 64;  // Number of words to drop
173     const size_t shift = s % 64; // Bits to shift in the remaining words.
174     const uint64_t mask = ((uint64_t(1) << shift) - 1) << (64 - shift);
175 
176     for (size_t i = WordCount; drop > 0 && i > 0; --i) {
177       if (i > drop)
178         val[i - 1] = val[i - drop - 1];
179       else
180         val[i - 1] = 0;
181     }
182     for (size_t i = WordCount; shift > 0 && i > drop; --i) {
183       uint64_t drop_val = (val[i - 1] & mask) >> (64 - shift);
184       val[i - 1] <<= shift;
185       if (i < WordCount)
186         val[i] |= drop_val;
187     }
188   }
189 
190   constexpr UInt<Bits> operator<<(size_t s) const {
191     UInt<Bits> result(*this);
192     result.shift_left(s);
193     return result;
194   }
195 
196   constexpr UInt<Bits> &operator<<=(size_t s) {
197     shift_left(s);
198     return *this;
199   }
200 
shift_right(size_t s)201   constexpr void shift_right(size_t s) {
202     const size_t drop = s / 64;  // Number of words to drop
203     const size_t shift = s % 64; // Bit shift in the remaining words.
204     const uint64_t mask = (uint64_t(1) << shift) - 1;
205 
206     for (size_t i = 0; drop > 0 && i < WordCount; ++i) {
207       if (i + drop < WordCount)
208         val[i] = val[i + drop];
209       else
210         val[i] = 0;
211     }
212     for (size_t i = 0; shift > 0 && i < WordCount; ++i) {
213       uint64_t drop_val = ((val[i] & mask) << (64 - shift));
214       val[i] >>= shift;
215       if (i > 0)
216         val[i - 1] |= drop_val;
217     }
218   }
219 
220   constexpr UInt<Bits> operator>>(size_t s) const {
221     UInt<Bits> result(*this);
222     result.shift_right(s);
223     return result;
224   }
225 
226   constexpr UInt<Bits> &operator>>=(size_t s) {
227     shift_right(s);
228     return *this;
229   }
230 
231   constexpr UInt<Bits> operator&(const UInt<Bits> &other) const {
232     UInt<Bits> result;
233     for (size_t i = 0; i < WordCount; ++i)
234       result.val[i] = val[i] & other.val[i];
235     return result;
236   }
237 
238   constexpr UInt<Bits> &operator&=(const UInt<Bits> &other) {
239     for (size_t i = 0; i < WordCount; ++i)
240       val[i] &= other.val[i];
241     return *this;
242   }
243 
244   constexpr UInt<Bits> operator|(const UInt<Bits> &other) const {
245     UInt<Bits> result;
246     for (size_t i = 0; i < WordCount; ++i)
247       result.val[i] = val[i] | other.val[i];
248     return result;
249   }
250 
251   constexpr UInt<Bits> &operator|=(const UInt<Bits> &other) {
252     for (size_t i = 0; i < WordCount; ++i)
253       val[i] |= other.val[i];
254     return *this;
255   }
256 
257   constexpr UInt<Bits> operator^(const UInt<Bits> &other) const {
258     UInt<Bits> result;
259     for (size_t i = 0; i < WordCount; ++i)
260       result.val[i] = val[i] ^ other.val[i];
261     return result;
262   }
263 
264   constexpr UInt<Bits> &operator^=(const UInt<Bits> &other) {
265     for (size_t i = 0; i < WordCount; ++i)
266       val[i] ^= other.val[i];
267     return *this;
268   }
269 
270   constexpr UInt<Bits> operator~() const {
271     UInt<Bits> result;
272     for (size_t i = 0; i < WordCount; ++i)
273       result.val[i] = ~val[i];
274     return result;
275   }
276 
277   constexpr bool operator==(const UInt<Bits> &other) const {
278     for (size_t i = 0; i < WordCount; ++i) {
279       if (val[i] != other.val[i])
280         return false;
281     }
282     return true;
283   }
284 
285   constexpr bool operator!=(const UInt<Bits> &other) const {
286     for (size_t i = 0; i < WordCount; ++i) {
287       if (val[i] != other.val[i])
288         return true;
289     }
290     return false;
291   }
292 
293   constexpr bool operator>(const UInt<Bits> &other) const {
294     for (size_t i = WordCount; i > 0; --i) {
295       uint64_t word = val[i - 1];
296       uint64_t other_word = other.val[i - 1];
297       if (word > other_word)
298         return true;
299       else if (word < other_word)
300         return false;
301     }
302     // Equal
303     return false;
304   }
305 
306   constexpr bool operator>=(const UInt<Bits> &other) const {
307     for (size_t i = WordCount; i > 0; --i) {
308       uint64_t word = val[i - 1];
309       uint64_t other_word = other.val[i - 1];
310       if (word > other_word)
311         return true;
312       else if (word < other_word)
313         return false;
314     }
315     // Equal
316     return true;
317   }
318 
319   constexpr bool operator<(const UInt<Bits> &other) const {
320     for (size_t i = WordCount; i > 0; --i) {
321       uint64_t word = val[i - 1];
322       uint64_t other_word = other.val[i - 1];
323       if (word > other_word)
324         return false;
325       else if (word < other_word)
326         return true;
327     }
328     // Equal
329     return false;
330   }
331 
332   constexpr bool operator<=(const UInt<Bits> &other) const {
333     for (size_t i = WordCount; i > 0; --i) {
334       uint64_t word = val[i - 1];
335       uint64_t other_word = other.val[i - 1];
336       if (word > other_word)
337         return false;
338       else if (word < other_word)
339         return true;
340     }
341     // Equal
342     return true;
343   }
344 
345   constexpr UInt<Bits> &operator++() {
346     UInt<Bits> one(1);
347     add(one);
348     return *this;
349   }
350 
351   // Return the i-th 64-bit word of the number.
352   constexpr const uint64_t &operator[](size_t i) const { return val[i]; }
353 
354   // Return the i-th 64-bit word of the number.
355   constexpr uint64_t &operator[](size_t i) { return val[i]; }
356 
data()357   uint64_t *data() { return val; }
358 
data()359   const uint64_t *data() const { return val; }
360 };
361 
362 template <>
363 constexpr UInt<128> UInt<128>::operator*(const UInt<128> &other) const {
364   // temp low covers bits 0-63, middle covers 32-95, high covers 64-127, and
365   // high overflow covers 96-159.
366   uint64_t temp_low = low(val[0]) * low(other[0]);
367   uint64_t temp_middle_1 = low(val[0]) * high(other[0]);
368   uint64_t temp_middle_2 = high(val[0]) * low(other[0]);
369 
370   // temp_middle is split out so that overflows can be handled, but since
371   // but since the result will be truncated to 128 bits any overflow from here
372   // on doesn't matter.
373   uint64_t temp_high = low(val[0]) * low(other[1]) +
374                        high(val[0]) * high(other[0]) +
375                        low(val[1]) * low(other[0]);
376 
377   uint64_t temp_high_overflow =
378       low(val[0]) * high(other[1]) + high(val[0]) * low(other[1]) +
379       low(val[1]) * high(other[0]) + high(val[1]) * low(other[0]);
380 
381   // temp_low_middle has just the high 32 bits of low, as well as any
382   // overflow.
383   uint64_t temp_low_middle =
384       high(temp_low) + low(temp_middle_1) + low(temp_middle_2);
385 
386   uint64_t new_low = low(temp_low) + (low(temp_low_middle) << 32);
387   uint64_t new_high = high(temp_low_middle) + high(temp_middle_1) +
388                       high(temp_middle_2) + temp_high +
389                       (low(temp_high_overflow) << 32);
390   UInt<128> result(0);
391   result[0] = new_low;
392   result[1] = new_high;
393   return result;
394 }
395 
396 } // namespace cpp
397 } // namespace __llvm_libc
398 
399 #endif // LLVM_LIBC_UTILS_CPP_UINT_H
400