1 use core::fmt;
2 
3 use super::host_page_size;
4 
/// A byte count that is always a whole multiple of the host page size.
///
/// Intended for bookkeeping around page-aligned memory allocations. The inner
/// value is private, so outside this module it can only be obtained through
/// the constructors and arithmetic helpers defined here.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct HostAlignedByteCount(
    // Invariant: the wrapped value is a multiple of the host page size.
    usize,
);
13 
14 impl HostAlignedByteCount {
15     /// A zero byte count.
16     pub const ZERO: Self = Self(0);
17 
18     /// Creates a new `HostAlignedByteCount` from an aligned byte count.
19     ///
20     /// Returns an error if `bytes` is not page-aligned.
new(bytes: usize) -> Result<Self, ByteCountNotAligned>21     pub fn new(bytes: usize) -> Result<Self, ByteCountNotAligned> {
22         let host_page_size = host_page_size();
23         if bytes % host_page_size == 0 {
24             Ok(Self(bytes))
25         } else {
26             Err(ByteCountNotAligned(bytes))
27         }
28     }
29 
30     /// Creates a new `HostAlignedByteCount` from an aligned byte count without
31     /// checking validity.
32     ///
33     /// ## Safety
34     ///
35     /// The caller must ensure that `bytes` is page-aligned.
new_unchecked(bytes: usize) -> Self36     pub unsafe fn new_unchecked(bytes: usize) -> Self {
37         debug_assert!(
38             bytes % host_page_size() == 0,
39             "byte count {bytes} is not page-aligned (page size = {})",
40             host_page_size(),
41         );
42         Self(bytes)
43     }
44 
45     /// Creates a new `HostAlignedByteCount`, rounding up to the nearest page.
46     ///
47     /// Returns an error if `bytes + page_size - 1` overflows.
new_rounded_up(bytes: usize) -> Result<Self, ByteCountOutOfBounds>48     pub fn new_rounded_up(bytes: usize) -> Result<Self, ByteCountOutOfBounds> {
49         let page_size = host_page_size();
50         debug_assert!(page_size.is_power_of_two());
51         match bytes.checked_add(page_size - 1) {
52             Some(v) => Ok(Self(v & !(page_size - 1))),
53             None => Err(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::RoundUp)),
54         }
55     }
56 
57     /// Creates a new `HostAlignedByteCount` from a `u64`, rounding up to the nearest page.
58     ///
59     /// Returns an error if the `u64` overflows `usize`, or if `bytes +
60     /// page_size - 1` overflows.
new_rounded_up_u64(bytes: u64) -> Result<Self, ByteCountOutOfBounds>61     pub fn new_rounded_up_u64(bytes: u64) -> Result<Self, ByteCountOutOfBounds> {
62         let bytes = bytes
63             .try_into()
64             .map_err(|_| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::ConvertU64))?;
65         Self::new_rounded_up(bytes)
66     }
67 
68     /// Returns the host page size.
host_page_size() -> HostAlignedByteCount69     pub fn host_page_size() -> HostAlignedByteCount {
70         // The host page size is always a multiple of itself.
71         HostAlignedByteCount(host_page_size())
72     }
73 
74     /// Returns true if the page count is zero.
75     #[inline]
is_zero(self) -> bool76     pub fn is_zero(self) -> bool {
77         self == Self::ZERO
78     }
79 
80     /// Returns the number of bytes as a `usize`.
81     #[inline]
byte_count(self) -> usize82     pub fn byte_count(self) -> usize {
83         self.0
84     }
85 
86     /// Add two aligned byte counts together.
87     ///
88     /// Returns an error if the result overflows.
checked_add(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds>89     pub fn checked_add(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
90         // aligned + aligned = aligned
91         self.0
92             .checked_add(bytes.0)
93             .map(Self)
94             .ok_or(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Add))
95     }
96 
97     // Note: saturating_add should not be naively added! usize::MAX is not a
98     // power of 2 so is not aligned.
99 
100     /// Compute `self - bytes`.
101     ///
102     /// Returns an error if the result underflows.
checked_sub(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds>103     pub fn checked_sub(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
104         // aligned - aligned = aligned
105         self.0
106             .checked_sub(bytes.0)
107             .map(Self)
108             .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Sub))
109     }
110 
111     /// Compute `self - bytes`, returning zero if the result underflows.
112     #[inline]
saturating_sub(self, bytes: HostAlignedByteCount) -> Self113     pub fn saturating_sub(self, bytes: HostAlignedByteCount) -> Self {
114         // aligned - aligned = aligned, and 0 is always aligned.
115         Self(self.0.saturating_sub(bytes.0))
116     }
117 
118     /// Multiply an aligned byte count by a scalar value.
119     ///
120     /// Returns an error if the result overflows.
checked_mul(self, scalar: usize) -> Result<Self, ByteCountOutOfBounds>121     pub fn checked_mul(self, scalar: usize) -> Result<Self, ByteCountOutOfBounds> {
122         // aligned * scalar = aligned
123         self.0
124             .checked_mul(scalar)
125             .map(Self)
126             .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Mul))
127     }
128 
129     /// Divide an aligned byte count by another aligned byte count, producing a
130     /// scalar value.
131     ///
132     /// Returns an error in case the divisor is zero.
checked_div(self, divisor: HostAlignedByteCount) -> Result<usize, ByteCountOutOfBounds>133     pub fn checked_div(self, divisor: HostAlignedByteCount) -> Result<usize, ByteCountOutOfBounds> {
134         self.0
135             .checked_div(divisor.0)
136             .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Div))
137     }
138 
139     /// Compute the remainder of an aligned byte count divided by another
140     /// aligned byte count.
141     ///
142     /// The remainder is always an aligned byte count itself.
143     ///
144     /// Returns an error in case the divisor is zero.
checked_rem(self, divisor: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds>145     pub fn checked_rem(self, divisor: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
146         // Why is the remainder an aligned byte count? For example, if the page
147         // size is 4KiB, then the remainder of dividing (say) 40KiB by 16KiB is
148         // 8KiB, which is a multiple of 4KiB.
149         //
150         // More generally, for integers n >= 0, m > 0, k > 0:
151         //
152         //     (n * k) % (m * k) = (n % m) * k
153         //
154         // which is a multiple of k. Here, k is the host page size, so the
155         // remainder is a multiple of the host page size.
156         self.0
157             .checked_rem(divisor.0)
158             .map(Self)
159             .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Rem))
160     }
161 
162     /// Unchecked multiplication by a scalar value.
163     ///
164     /// ## Safety
165     ///
166     /// The result must not overflow.
167     #[inline]
unchecked_mul(self, n: usize) -> Self168     pub unsafe fn unchecked_mul(self, n: usize) -> Self {
169         Self(self.0 * n)
170     }
171 }
172 
173 impl PartialEq<usize> for HostAlignedByteCount {
174     #[inline]
eq(&self, other: &usize) -> bool175     fn eq(&self, other: &usize) -> bool {
176         self.0 == *other
177     }
178 }
179 
180 impl PartialEq<HostAlignedByteCount> for usize {
181     #[inline]
eq(&self, other: &HostAlignedByteCount) -> bool182     fn eq(&self, other: &HostAlignedByteCount) -> bool {
183         *self == other.0
184     }
185 }
186 
/// Adapter whose `Display` impl renders the wrapped value as lowercase
/// hexadecimal, always with a leading `0x`.
struct LowerHexDisplay<T>(T);

impl<T: fmt::LowerHex> fmt::Display for LowerHexDisplay<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Without the `#` flag, force the `0x` prefix by re-formatting with
        // `{:#x}`. That builds fresh format arguments, so any fill, width or
        // alignment the caller asked for is dropped: it is hard to build a
        // formatter that copies the caller's options but sets the alternate
        // flag. https://github.com/rust-lang/rust/pull/118159 would make this
        // easier.
        if !f.alternate() {
            return write!(f, "{:#x}", self.0);
        }

        // The caller already requested `{:#}`, so hand the formatter straight
        // to `LowerHex`, which then honours all of its options.
        fmt::LowerHex::fmt(&self.0, f)
    }
}
206 
207 impl fmt::Display for HostAlignedByteCount {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result208     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209         // Use the LowerHex impl as the Display impl, ensuring that there's
210         // always a 0x in the beginning (i.e. that the alternate formatter is
211         // used.)
212         fmt::Display::fmt(&LowerHexDisplay(self.0), f)
213     }
214 }
215 
216 impl fmt::LowerHex for HostAlignedByteCount {
217     #[inline]
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result218     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219         fmt::LowerHex::fmt(&self.0, f)
220     }
221 }
222 
223 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
224 pub struct ByteCountNotAligned(usize);
225 
226 impl fmt::Display for ByteCountNotAligned {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result227     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228         write!(
229             f,
230             "byte count not page-aligned: {}",
231             LowerHexDisplay(self.0)
232         )
233     }
234 }
235 
236 impl core::error::Error for ByteCountNotAligned {}
237 
238 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
239 pub struct ByteCountOutOfBounds(ByteCountOutOfBoundsKind);
240 
241 impl fmt::Display for ByteCountOutOfBounds {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result242     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243         write!(f, "{}", self.0)
244     }
245 }
246 
247 impl core::error::Error for ByteCountOutOfBounds {}
248 
/// The operation that produced a [`ByteCountOutOfBounds`] error.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ByteCountOutOfBoundsKind {
    // The operands that triggered the error are deliberately not stored, so
    // the error type stays small.
    /// Rounding up to a page boundary overflowed.
    RoundUp,
    /// A `u64` byte count could not be converted.
    ConvertU64,
    /// Addition overflowed.
    Add,
    /// Subtraction underflowed.
    Sub,
    /// Multiplication overflowed.
    Mul,
    /// Division by zero.
    Div,
    /// Remainder by zero.
    Rem,
}

impl fmt::Display for ByteCountOutOfBoundsKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            Self::RoundUp => "byte count overflow rounding up",
            Self::ConvertU64 => "byte count overflow converting u64",
            Self::Add => "byte count overflow during addition",
            Self::Sub => "byte count underflow during subtraction",
            Self::Mul => "byte count overflow during multiplication",
            Self::Div => "division by zero",
            Self::Rem => "remainder by zero",
        };
        f.write_str(description)
    }
}
279 
#[cfg(test)]
mod proptest_impls {
    use super::*;

    use proptest::prelude::*;

    impl Arbitrary for HostAlignedByteCount {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Generates page-aligned byte counts uniformly across every aligned
        /// value that fits in a `usize`.
        fn arbitrary_with(_: ()) -> Self::Strategy {
            // Draw a page index uniformly from 0..=max_pages and scale it by
            // the page size. For instance, with a 64-bit usize (usize::MAX =
            // 2**64 - 1) and 4KiB pages (2**12), max_pages is
            // floor(usize::MAX / 2**12) = 2**52 - 1, so the generated byte
            // counts span 0..=(2**64 - 2**12) in 2**12 steps, uniformly.
            let max_pages = usize::MAX / host_page_size();
            let pages = 0..=max_pages;
            pages
                .prop_map(|page_index| {
                    HostAlignedByteCount::new(page_index * host_page_size()).unwrap()
                })
                .boxed()
        }
    }
}
309 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn byte_count_display() {
        // Pages should hopefully be 64k or smaller.
        let byte_count = HostAlignedByteCount::new(65536).unwrap();

        assert_eq!(format!("{byte_count}"), "0x10000");
        // The alternate flag must not produce a doubled `0x` prefix.
        assert_eq!(format!("{byte_count:#}"), "0x10000");
        assert_eq!(format!("{byte_count:x}"), "10000");
        assert_eq!(format!("{byte_count:#x}"), "0x10000");
    }

    #[test]
    fn byte_count_not_aligned_display() {
        let host_page_size = host_page_size();
        let err = HostAlignedByteCount::new(host_page_size + 1)
            .expect_err("host_page_size + 1 is not aligned");
        assert_eq!(
            err.to_string(),
            format!("byte count not page-aligned: {:#x}", host_page_size + 1)
        );
    }

    #[test]
    fn byte_count_ops() {
        let host_page_size = host_page_size();
        HostAlignedByteCount::new(0).expect("0 is aligned");
        HostAlignedByteCount::new(host_page_size).expect("host_page_size is aligned");
        HostAlignedByteCount::new(host_page_size * 2).expect("host_page_size * 2 is aligned");
        HostAlignedByteCount::new(host_page_size + 1)
            .expect_err("host_page_size + 1 is not aligned");
        HostAlignedByteCount::new(host_page_size / 2)
            .expect_err("host_page_size / 2 is not aligned");

        // Rounding up.
        HostAlignedByteCount::new_rounded_up(usize::MAX).expect_err("usize::MAX overflows");
        assert_eq!(
            HostAlignedByteCount::new_rounded_up(usize::MAX - host_page_size)
                .expect("(usize::MAX - 1 page) is in bounds"),
            HostAlignedByteCount::new((usize::MAX - host_page_size) + 1)
                .expect("usize::MAX is 2**N - 1"),
        );

        // Addition.
        let half_max = HostAlignedByteCount::new((usize::MAX >> 1) + 1)
            .expect("(usize::MAX >> 1) + 1 is aligned");
        half_max
            .checked_add(HostAlignedByteCount::host_page_size())
            .expect("half max + page size is in bounds");
        half_max
            .checked_add(half_max)
            .expect_err("half max + half max is out of bounds");

        // Subtraction.
        let half_max_minus_one = half_max
            .checked_sub(HostAlignedByteCount::host_page_size())
            .expect("(half_max - 1 page) is in bounds");
        assert_eq!(
            half_max.checked_sub(half_max),
            Ok(HostAlignedByteCount::ZERO)
        );
        assert_eq!(
            half_max.checked_sub(half_max_minus_one),
            Ok(HostAlignedByteCount::host_page_size())
        );
        half_max_minus_one
            .checked_sub(half_max)
            .expect_err("(half_max - 1 page) - half_max is out of bounds");

        // Saturating subtraction clamps underflow to zero.
        assert_eq!(
            half_max_minus_one.saturating_sub(half_max),
            HostAlignedByteCount::ZERO
        );
        assert_eq!(
            half_max.saturating_sub(half_max_minus_one),
            HostAlignedByteCount::host_page_size()
        );

        // Multiplication.
        half_max
            .checked_mul(2)
            .expect_err("half max * 2 is out of bounds");
        half_max_minus_one
            .checked_mul(2)
            .expect("(half max - 1 page) * 2 is in bounds");
    }

    #[test]
    fn byte_count_rounding() {
        let page = HostAlignedByteCount::host_page_size();

        // Already-aligned values (including zero) are left unchanged.
        assert_eq!(
            HostAlignedByteCount::new_rounded_up(0),
            Ok(HostAlignedByteCount::ZERO)
        );
        assert_eq!(
            HostAlignedByteCount::new_rounded_up(page.byte_count()),
            Ok(page)
        );
        // Anything inside the first page rounds up to exactly one page.
        assert_eq!(HostAlignedByteCount::new_rounded_up(1), Ok(page));
        assert_eq!(HostAlignedByteCount::new_rounded_up_u64(1), Ok(page));
        // Either the u64 -> usize conversion or the round-up must fail here,
        // depending on the pointer width.
        HostAlignedByteCount::new_rounded_up_u64(u64::MAX)
            .expect_err("u64::MAX cannot be rounded up to a page multiple");
    }

    #[test]
    fn byte_count_div_rem() {
        let page = HostAlignedByteCount::host_page_size();
        let two_pages = page.checked_mul(2).expect("2 pages is in bounds");
        let three_pages = page.checked_mul(3).expect("3 pages is in bounds");

        assert!(HostAlignedByteCount::ZERO.is_zero());
        assert!(!page.is_zero());

        assert_eq!(three_pages.checked_div(page), Ok(3));
        assert_eq!(three_pages.checked_rem(two_pages), Ok(page));
        assert_eq!(two_pages.checked_rem(page), Ok(HostAlignedByteCount::ZERO));

        three_pages
            .checked_div(HostAlignedByteCount::ZERO)
            .expect_err("division by zero");
        three_pages
            .checked_rem(HostAlignedByteCount::ZERO)
            .expect_err("remainder by zero");
    }
}
379