1 //===-- A class to manipulate wide integers. --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_CPP_UINT_H
10 #define LLVM_LIBC_UTILS_CPP_UINT_H
11 
12 #include "Array.h"
13 
14 #include <stddef.h> // For size_t
15 #include <stdint.h>
16 
17 namespace __llvm_libc {
18 namespace cpp {
19 
20 template <size_t Bits> class UInt {
21 
22   static_assert(Bits > 0 && Bits % 64 == 0,
23                 "Number of bits in UInt should be a multiple of 64.");
24   static constexpr size_t WordCount = Bits / 64;
25   uint64_t val[WordCount];
26 
27   static constexpr uint64_t MASK32 = 0xFFFFFFFFu;
28 
29   static constexpr uint64_t low(uint64_t v) { return v & MASK32; }
30   static constexpr uint64_t high(uint64_t v) { return (v >> 32) & MASK32; }
31 
32 public:
33   constexpr UInt() {}
34 
35   constexpr UInt(const UInt<Bits> &other) {
36     for (size_t i = 0; i < WordCount; ++i)
37       val[i] = other.val[i];
38   }
39 
40   // Initialize the first word to |v| and the rest to 0.
41   constexpr UInt(uint64_t v) {
42     val[0] = v;
43     for (size_t i = 1; i < WordCount; ++i) {
44       val[i] = 0;
45     }
46   }
47   constexpr explicit UInt(const cpp::Array<uint64_t, WordCount> &words) {
48     for (size_t i = 0; i < WordCount; ++i)
49       val[i] = words[i];
50   }
51 
52   constexpr explicit operator uint64_t() const { return val[0]; }
53 
54   constexpr explicit operator uint32_t() const {
55     return uint32_t(uint64_t(*this));
56   }
57 
58   constexpr explicit operator uint8_t() const {
59     return uint8_t(uint64_t(*this));
60   }
61 
62   UInt<Bits> &operator=(const UInt<Bits> &other) {
63     for (size_t i = 0; i < WordCount; ++i)
64       val[i] = other.val[i];
65     return *this;
66   }
67 
68   // Add x to this number and store the result in this number.
69   // Returns the carry value produced by the addition operation.
70   constexpr uint64_t add(const UInt<Bits> &x) {
71     uint64_t carry = 0;
72     for (size_t i = 0; i < WordCount; ++i) {
73       uint64_t res_lo = low(val[i]) + low(x.val[i]) + carry;
74       carry = high(res_lo);
75       res_lo = low(res_lo);
76 
77       uint64_t res_hi = high(val[i]) + high(x.val[i]) + carry;
78       carry = high(res_hi);
79       res_hi = low(res_hi);
80 
81       val[i] = res_lo + (res_hi << 32);
82     }
83     return carry;
84   }
85 
86   constexpr UInt<Bits> operator+(const UInt<Bits> &other) const {
87     UInt<Bits> result(*this);
88     result.add(other);
89     return result;
90   }
91 
92   // Multiply this number with x and store the result in this number. It is
93   // implemented using the long multiplication algorithm by splitting the
94   // 64-bit words of this number and |x| in to 32-bit halves but peforming
95   // the operations using 64-bit numbers. This ensures that we don't lose the
96   // carry bits.
97   // Returns the carry value produced by the multiplication operation.
98   constexpr uint64_t mul(uint64_t x) {
99     uint64_t x_lo = low(x);
100     uint64_t x_hi = high(x);
101 
102     cpp::Array<uint64_t, WordCount + 1> row1;
103     uint64_t carry = 0;
104     for (size_t i = 0; i < WordCount; ++i) {
105       uint64_t l = low(val[i]);
106       uint64_t h = high(val[i]);
107       uint64_t p1 = x_lo * l;
108       uint64_t p2 = x_lo * h;
109 
110       uint64_t res_lo = low(p1) + carry;
111       carry = high(res_lo);
112       uint64_t res_hi = high(p1) + low(p2) + carry;
113       carry = high(res_hi) + high(p2);
114 
115       res_lo = low(res_lo);
116       res_hi = low(res_hi);
117       row1[i] = res_lo + (res_hi << 32);
118     }
119     row1[WordCount] = carry;
120 
121     cpp::Array<uint64_t, WordCount + 1> row2;
122     row2[0] = 0;
123     carry = 0;
124     for (size_t i = 0; i < WordCount; ++i) {
125       uint64_t l = low(val[i]);
126       uint64_t h = high(val[i]);
127       uint64_t p1 = x_hi * l;
128       uint64_t p2 = x_hi * h;
129 
130       uint64_t res_lo = low(p1) + carry;
131       carry = high(res_lo);
132       uint64_t res_hi = high(p1) + low(p2) + carry;
133       carry = high(res_hi) + high(p2);
134 
135       res_lo = low(res_lo);
136       res_hi = low(res_hi);
137       row2[i] = res_lo + (res_hi << 32);
138     }
139     row2[WordCount] = carry;
140 
141     UInt<(WordCount + 1) * 64> r1(row1), r2(row2);
142     r2.shift_left(32);
143     r1.add(r2);
144     for (size_t i = 0; i < WordCount; ++i) {
145       val[i] = r1[i];
146     }
147     return r1[WordCount];
148   }
149 
150   constexpr UInt<Bits> operator*(const UInt<Bits> &other) const {
151     UInt<Bits> result(0);
152     for (size_t i = 0; i < WordCount; ++i) {
153       UInt<Bits> row_result(*this);
154       row_result.mul(other[i]);
155       row_result.shift_left(64 * i);
156       result = result + row_result;
157     }
158     return result;
159   }
160 
161   constexpr void shift_left(size_t s) {
162     const size_t drop = s / 64;  // Number of words to drop
163     const size_t shift = s % 64; // Bits to shift in the remaining words.
164     const uint64_t mask = ((uint64_t(1) << shift) - 1) << (64 - shift);
165 
166     for (size_t i = WordCount; drop > 0 && i > 0; --i) {
167       if (i > drop)
168         val[i - 1] = val[i - drop - 1];
169       else
170         val[i - 1] = 0;
171     }
172     for (size_t i = WordCount; shift > 0 && i > drop; --i) {
173       uint64_t drop_val = (val[i - 1] & mask) >> (64 - shift);
174       val[i - 1] <<= shift;
175       if (i < WordCount)
176         val[i] |= drop_val;
177     }
178   }
179 
180   constexpr UInt<Bits> operator<<(size_t s) const {
181     UInt<Bits> result(*this);
182     result.shift_left(s);
183     return result;
184   }
185 
186   constexpr UInt<Bits> &operator<<=(size_t s) {
187     shift_left(s);
188     return *this;
189   }
190 
191   constexpr void shift_right(size_t s) {
192     const size_t drop = s / 64;  // Number of words to drop
193     const size_t shift = s % 64; // Bit shift in the remaining words.
194     const uint64_t mask = (uint64_t(1) << shift) - 1;
195 
196     for (size_t i = 0; drop > 0 && i < WordCount; ++i) {
197       if (i + drop < WordCount)
198         val[i] = val[i + drop];
199       else
200         val[i] = 0;
201     }
202     for (size_t i = 0; shift > 0 && i < WordCount; ++i) {
203       uint64_t drop_val = ((val[i] & mask) << (64 - shift));
204       val[i] >>= shift;
205       if (i > 0)
206         val[i - 1] |= drop_val;
207     }
208   }
209 
210   constexpr UInt<Bits> operator>>(size_t s) const {
211     UInt<Bits> result(*this);
212     result.shift_right(s);
213     return result;
214   }
215 
216   constexpr UInt<Bits> &operator>>=(size_t s) {
217     shift_right(s);
218     return *this;
219   }
220 
221   constexpr UInt<Bits> operator&(const UInt<Bits> &other) const {
222     UInt<Bits> result;
223     for (size_t i = 0; i < WordCount; ++i)
224       result.val[i] = val[i] & other.val[i];
225     return result;
226   }
227 
228   constexpr UInt<Bits> operator|(const UInt<Bits> &other) const {
229     UInt<Bits> result;
230     for (size_t i = 0; i < WordCount; ++i)
231       result.val[i] = val[i] | other.val[i];
232     return result;
233   }
234 
235   constexpr UInt<Bits> operator^(const UInt<Bits> &other) const {
236     UInt<Bits> result;
237     for (size_t i = 0; i < WordCount; ++i)
238       result.val[i] = val[i] ^ other.val[i];
239     return result;
240   }
241 
242   constexpr UInt<Bits> operator~() const {
243     UInt<Bits> result;
244     for (size_t i = 0; i < WordCount; ++i)
245       result.val[i] = ~val[i];
246     return result;
247   }
248 
249   constexpr bool operator==(const UInt<Bits> &other) const {
250     for (size_t i = 0; i < WordCount; ++i) {
251       if (val[i] != other.val[i])
252         return false;
253     }
254     return true;
255   }
256 
257   constexpr bool operator!=(const UInt<Bits> &other) const {
258     for (size_t i = 0; i < WordCount; ++i) {
259       if (val[i] != other.val[i])
260         return true;
261     }
262     return false;
263   }
264 
265   constexpr bool operator>(const UInt<Bits> &other) const {
266     for (size_t i = WordCount; i > 0; --i) {
267       uint64_t word = val[i - 1];
268       uint64_t other_word = other.val[i - 1];
269       if (word > other_word)
270         return true;
271       else if (word < other_word)
272         return false;
273     }
274     // Equal
275     return false;
276   }
277 
278   constexpr bool operator>=(const UInt<Bits> &other) const {
279     for (size_t i = WordCount; i > 0; --i) {
280       uint64_t word = val[i - 1];
281       uint64_t other_word = other.val[i - 1];
282       if (word > other_word)
283         return true;
284       else if (word < other_word)
285         return false;
286     }
287     // Equal
288     return true;
289   }
290 
291   constexpr bool operator<(const UInt<Bits> &other) const {
292     for (size_t i = WordCount; i > 0; --i) {
293       uint64_t word = val[i - 1];
294       uint64_t other_word = other.val[i - 1];
295       if (word > other_word)
296         return false;
297       else if (word < other_word)
298         return true;
299     }
300     // Equal
301     return false;
302   }
303 
304   constexpr bool operator<=(const UInt<Bits> &other) const {
305     for (size_t i = WordCount; i > 0; --i) {
306       uint64_t word = val[i - 1];
307       uint64_t other_word = other.val[i - 1];
308       if (word > other_word)
309         return false;
310       else if (word < other_word)
311         return true;
312     }
313     // Equal
314     return true;
315   }
316 
317   // Return the i-th 64-bit word of the number.
318   constexpr const uint64_t &operator[](size_t i) const { return val[i]; }
319 
320   // Return the i-th 64-bit word of the number.
321   constexpr uint64_t &operator[](size_t i) { return val[i]; }
322 
323   uint64_t *data() { return val; }
324 
325   const uint64_t *data() const { return val; }
326 };
327 
328 template <>
329 constexpr UInt<128> UInt<128>::operator*(const UInt<128> &other) const {
330   // temp low covers bits 0-63, middle covers 32-95, high covers 64-127, and
331   // high overflow covers 96-159.
332   uint64_t temp_low = low(val[0]) * low(other[0]);
333   uint64_t temp_middle_1 = low(val[0]) * high(other[0]);
334   uint64_t temp_middle_2 = high(val[0]) * low(other[0]);
335 
336   // temp_middle is split out so that overflows can be handled, but since
337   // but since the result will be truncated to 128 bits any overflow from here
338   // on doesn't matter.
339   uint64_t temp_high = low(val[0]) * low(other[1]) +
340                        high(val[0]) * high(other[0]) +
341                        low(val[1]) * low(other[0]);
342 
343   uint64_t temp_high_overflow =
344       low(val[0]) * high(other[1]) + high(val[0]) * low(other[1]) +
345       low(val[1]) * high(other[0]) + high(val[1]) * low(other[0]);
346 
347   // temp_low_middle has just the high 32 bits of low, as well as any
348   // overflow.
349   uint64_t temp_low_middle =
350       high(temp_low) + low(temp_middle_1) + low(temp_middle_2);
351 
352   uint64_t new_low = low(temp_low) + (low(temp_low_middle) << 32);
353   uint64_t new_high = high(temp_low_middle) + high(temp_middle_1) +
354                       high(temp_middle_2) + temp_high +
355                       (low(temp_high_overflow) << 32);
356   UInt<128> result(0);
357   result[0] = new_low;
358   result[1] = new_high;
359   return result;
360 }
361 
362 } // namespace cpp
363 } // namespace __llvm_libc
364 
365 #endif // LLVM_LIBC_UTILS_CPP_UINT_H
366