1 use core::fmt; 2 3 use super::host_page_size; 4 5 /// A number of bytes that's guaranteed to be aligned to the host page size. 6 /// 7 /// This is used to manage page-aligned memory allocations. 8 #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] 9 pub struct HostAlignedByteCount( 10 // Invariant: this is always a multiple of the host page size. 11 usize, 12 ); 13 14 impl HostAlignedByteCount { 15 /// A zero byte count. 16 pub const ZERO: Self = Self(0); 17 18 /// Creates a new `HostAlignedByteCount` from an aligned byte count. 19 /// 20 /// Returns an error if `bytes` is not page-aligned. new(bytes: usize) -> Result<Self, ByteCountNotAligned>21 pub fn new(bytes: usize) -> Result<Self, ByteCountNotAligned> { 22 let host_page_size = host_page_size(); 23 if bytes % host_page_size == 0 { 24 Ok(Self(bytes)) 25 } else { 26 Err(ByteCountNotAligned(bytes)) 27 } 28 } 29 30 /// Creates a new `HostAlignedByteCount` from an aligned byte count without 31 /// checking validity. 32 /// 33 /// ## Safety 34 /// 35 /// The caller must ensure that `bytes` is page-aligned. new_unchecked(bytes: usize) -> Self36 pub unsafe fn new_unchecked(bytes: usize) -> Self { 37 debug_assert!( 38 bytes % host_page_size() == 0, 39 "byte count {bytes} is not page-aligned (page size = {})", 40 host_page_size(), 41 ); 42 Self(bytes) 43 } 44 45 /// Creates a new `HostAlignedByteCount`, rounding up to the nearest page. 46 /// 47 /// Returns an error if `bytes + page_size - 1` overflows. new_rounded_up(bytes: usize) -> Result<Self, ByteCountOutOfBounds>48 pub fn new_rounded_up(bytes: usize) -> Result<Self, ByteCountOutOfBounds> { 49 let page_size = host_page_size(); 50 debug_assert!(page_size.is_power_of_two()); 51 match bytes.checked_add(page_size - 1) { 52 Some(v) => Ok(Self(v & !(page_size - 1))), 53 None => Err(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::RoundUp)), 54 } 55 } 56 57 /// Creates a new `HostAlignedByteCount` from a `u64`, rounding up to the nearest page. 58 /// 59 /// Returns an error if the `u64` overflows `usize`, or if `bytes + 60 /// page_size - 1` overflows. new_rounded_up_u64(bytes: u64) -> Result<Self, ByteCountOutOfBounds>61 pub fn new_rounded_up_u64(bytes: u64) -> Result<Self, ByteCountOutOfBounds> { 62 let bytes = bytes 63 .try_into() 64 .map_err(|_| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::ConvertU64))?; 65 Self::new_rounded_up(bytes) 66 } 67 68 /// Returns the host page size. host_page_size() -> HostAlignedByteCount69 pub fn host_page_size() -> HostAlignedByteCount { 70 // The host page size is always a multiple of itself. 71 HostAlignedByteCount(host_page_size()) 72 } 73 74 /// Returns true if the page count is zero. 75 #[inline] is_zero(self) -> bool76 pub fn is_zero(self) -> bool { 77 self == Self::ZERO 78 } 79 80 /// Returns the number of bytes as a `usize`. 81 #[inline] byte_count(self) -> usize82 pub fn byte_count(self) -> usize { 83 self.0 84 } 85 86 /// Add two aligned byte counts together. 87 /// 88 /// Returns an error if the result overflows. checked_add(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds>89 pub fn checked_add(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> { 90 // aligned + aligned = aligned 91 self.0 92 .checked_add(bytes.0) 93 .map(Self) 94 .ok_or(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Add)) 95 } 96 97 // Note: saturating_add should not be naively added! usize::MAX is not a 98 // power of 2 so is not aligned. 99 100 /// Compute `self - bytes`. 101 /// 102 /// Returns an error if the result underflows. checked_sub(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds>103 pub fn checked_sub(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> { 104 // aligned - aligned = aligned 105 self.0 106 .checked_sub(bytes.0) 107 .map(Self) 108 .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Sub)) 109 } 110 111 /// Compute `self - bytes`, returning zero if the result underflows. 112 #[inline] saturating_sub(self, bytes: HostAlignedByteCount) -> Self113 pub fn saturating_sub(self, bytes: HostAlignedByteCount) -> Self { 114 // aligned - aligned = aligned, and 0 is always aligned. 115 Self(self.0.saturating_sub(bytes.0)) 116 } 117 118 /// Multiply an aligned byte count by a scalar value. 119 /// 120 /// Returns an error if the result overflows. checked_mul(self, scalar: usize) -> Result<Self, ByteCountOutOfBounds>121 pub fn checked_mul(self, scalar: usize) -> Result<Self, ByteCountOutOfBounds> { 122 // aligned * scalar = aligned 123 self.0 124 .checked_mul(scalar) 125 .map(Self) 126 .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Mul)) 127 } 128 129 /// Divide an aligned byte count by another aligned byte count, producing a 130 /// scalar value. 131 /// 132 /// Returns an error in case the divisor is zero. checked_div(self, divisor: HostAlignedByteCount) -> Result<usize, ByteCountOutOfBounds>133 pub fn checked_div(self, divisor: HostAlignedByteCount) -> Result<usize, ByteCountOutOfBounds> { 134 self.0 135 .checked_div(divisor.0) 136 .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Div)) 137 } 138 139 /// Compute the remainder of an aligned byte count divided by another 140 /// aligned byte count. 141 /// 142 /// The remainder is always an aligned byte count itself. 143 /// 144 /// Returns an error in case the divisor is zero. checked_rem(self, divisor: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds>145 pub fn checked_rem(self, divisor: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> { 146 // Why is the remainder an aligned byte count? For example, if the page 147 // size is 4KiB, then the remainder of dividing (say) 40KiB by 16KiB is 148 // 8KiB, which is a multiple of 4KiB. 149 // 150 // More generally, for integers n >= 0, m > 0, k > 0: 151 // 152 // (n * k) % (m * k) = (n % m) * k 153 // 154 // which is a multiple of k. Here, k is the host page size, so the 155 // remainder is a multiple of the host page size. 156 self.0 157 .checked_rem(divisor.0) 158 .map(Self) 159 .ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Rem)) 160 } 161 162 /// Unchecked multiplication by a scalar value. 163 /// 164 /// ## Safety 165 /// 166 /// The result must not overflow. 167 #[inline] unchecked_mul(self, n: usize) -> Self168 pub unsafe fn unchecked_mul(self, n: usize) -> Self { 169 Self(self.0 * n) 170 } 171 } 172 173 impl PartialEq<usize> for HostAlignedByteCount { 174 #[inline] eq(&self, other: &usize) -> bool175 fn eq(&self, other: &usize) -> bool { 176 self.0 == *other 177 } 178 } 179 180 impl PartialEq<HostAlignedByteCount> for usize { 181 #[inline] eq(&self, other: &HostAlignedByteCount) -> bool182 fn eq(&self, other: &HostAlignedByteCount) -> bool { 183 *self == other.0 184 } 185 } 186 187 struct LowerHexDisplay<T>(T); 188 189 impl<T: fmt::LowerHex> fmt::Display for LowerHexDisplay<T> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 191 // Use the LowerHex impl as the Display impl, ensuring that there's 192 // always a 0x in the beginning (i.e. that the alternate formatter is 193 // used.) 194 if f.alternate() { 195 fmt::LowerHex::fmt(&self.0, f) 196 } else { 197 // Unfortunately, fill and alignment aren't respected this way, but 198 // it's quite hard to construct a new formatter with mostly the same 199 // options but the alternate flag set. 200 // https://github.com/rust-lang/rust/pull/118159 would make this 201 // easier. 202 write!(f, "{:#x}", self.0) 203 } 204 } 205 } 206 207 impl fmt::Display for HostAlignedByteCount { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 209 // Use the LowerHex impl as the Display impl, ensuring that there's 210 // always a 0x in the beginning (i.e. that the alternate formatter is 211 // used.) 212 fmt::Display::fmt(&LowerHexDisplay(self.0), f) 213 } 214 } 215 216 impl fmt::LowerHex for HostAlignedByteCount { 217 #[inline] fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 219 fmt::LowerHex::fmt(&self.0, f) 220 } 221 } 222 223 #[derive(Clone, Copy, Debug, PartialEq, Eq)] 224 pub struct ByteCountNotAligned(usize); 225 226 impl fmt::Display for ByteCountNotAligned { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result227 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 228 write!( 229 f, 230 "byte count not page-aligned: {}", 231 LowerHexDisplay(self.0) 232 ) 233 } 234 } 235 236 impl core::error::Error for ByteCountNotAligned {} 237 238 #[derive(Clone, Copy, Debug, PartialEq, Eq)] 239 pub struct ByteCountOutOfBounds(ByteCountOutOfBoundsKind); 240 241 impl fmt::Display for ByteCountOutOfBounds { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 243 write!(f, "{}", self.0) 244 } 245 } 246 247 impl core::error::Error for ByteCountOutOfBounds {} 248 249 #[derive(Clone, Copy, Debug, PartialEq, Eq)] 250 enum ByteCountOutOfBoundsKind { 251 // We don't carry the arguments that errored out to avoid the error type 252 // becoming too big. 253 RoundUp, 254 ConvertU64, 255 Add, 256 Sub, 257 Mul, 258 Div, 259 Rem, 260 } 261 262 impl fmt::Display for ByteCountOutOfBoundsKind { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result263 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 264 match self { 265 ByteCountOutOfBoundsKind::RoundUp => f.write_str("byte count overflow rounding up"), 266 ByteCountOutOfBoundsKind::ConvertU64 => { 267 f.write_str("byte count overflow converting u64") 268 } 269 ByteCountOutOfBoundsKind::Add => f.write_str("byte count overflow during addition"), 270 ByteCountOutOfBoundsKind::Sub => f.write_str("byte count underflow during subtraction"), 271 ByteCountOutOfBoundsKind::Mul => { 272 f.write_str("byte count overflow during multiplication") 273 } 274 ByteCountOutOfBoundsKind::Div => f.write_str("division by zero"), 275 ByteCountOutOfBoundsKind::Rem => f.write_str("remainder by zero"), 276 } 277 } 278 } 279 280 #[cfg(test)] 281 mod proptest_impls { 282 use super::*; 283 284 use proptest::prelude::*; 285 286 impl Arbitrary for HostAlignedByteCount { 287 type Strategy = BoxedStrategy<Self>; 288 type Parameters = (); 289 arbitrary_with(_: ()) -> Self::Strategy290 fn arbitrary_with(_: ()) -> Self::Strategy { 291 // Compute the number of pages that fit in a usize, rounded down. 292 // For example, if: 293 // 294 // * usize::MAX is 2**64 - 1 295 // * host_page_size is 2**12 (4KiB) 296 // 297 // Then page_count = floor(usize::MAX / host_page_size) = 2**52 - 1. 298 // The range 0..=page_count, when multiplied by the page size, will 299 // produce values in the range 0..=(2**64 - 2**12), in steps of 300 // 2**12, uniformly at random. This is the desired uniform 301 // distribution of byte counts. 302 let page_count = usize::MAX / host_page_size(); 303 (0..=page_count) 304 .prop_map(|n| HostAlignedByteCount::new(n * host_page_size()).unwrap()) 305 .boxed() 306 } 307 } 308 } 309 310 #[cfg(test)] 311 mod tests { 312 use super::*; 313 314 #[test] byte_count_display()315 fn byte_count_display() { 316 // Pages should hopefully be 64k or smaller. 317 let byte_count = HostAlignedByteCount::new(65536).unwrap(); 318 319 assert_eq!(format!("{byte_count}"), "0x10000"); 320 assert_eq!(format!("{byte_count:x}"), "10000"); 321 assert_eq!(format!("{byte_count:#x}"), "0x10000"); 322 } 323 324 #[test] byte_count_ops()325 fn byte_count_ops() { 326 let host_page_size = host_page_size(); 327 HostAlignedByteCount::new(0).expect("0 is aligned"); 328 HostAlignedByteCount::new(host_page_size).expect("host_page_size is aligned"); 329 HostAlignedByteCount::new(host_page_size * 2).expect("host_page_size * 2 is aligned"); 330 HostAlignedByteCount::new(host_page_size + 1) 331 .expect_err("host_page_size + 1 is not aligned"); 332 HostAlignedByteCount::new(host_page_size / 2) 333 .expect_err("host_page_size / 2 is not aligned"); 334 335 // Rounding up. 336 HostAlignedByteCount::new_rounded_up(usize::MAX).expect_err("usize::MAX overflows"); 337 assert_eq!( 338 HostAlignedByteCount::new_rounded_up(usize::MAX - host_page_size) 339 .expect("(usize::MAX - 1 page) is in bounds"), 340 HostAlignedByteCount::new((usize::MAX - host_page_size) + 1) 341 .expect("usize::MAX is 2**N - 1"), 342 ); 343 344 // Addition. 345 let half_max = HostAlignedByteCount::new((usize::MAX >> 1) + 1) 346 .expect("(usize::MAX >> 1) + 1 is aligned"); 347 half_max 348 .checked_add(HostAlignedByteCount::host_page_size()) 349 .expect("half max + page size is in bounds"); 350 half_max 351 .checked_add(half_max) 352 .expect_err("half max + half max is out of bounds"); 353 354 // Subtraction. 355 let half_max_minus_one = half_max 356 .checked_sub(HostAlignedByteCount::host_page_size()) 357 .expect("(half_max - 1 page) is in bounds"); 358 assert_eq!( 359 half_max.checked_sub(half_max), 360 Ok(HostAlignedByteCount::ZERO) 361 ); 362 assert_eq!( 363 half_max.checked_sub(half_max_minus_one), 364 Ok(HostAlignedByteCount::host_page_size()) 365 ); 366 half_max_minus_one 367 .checked_sub(half_max) 368 .expect_err("(half_max - 1 page) - half_max is out of bounds"); 369 370 // Multiplication. 371 half_max 372 .checked_mul(2) 373 .expect_err("half max * 2 is out of bounds"); 374 half_max_minus_one 375 .checked_mul(2) 376 .expect("(half max - 1 page) * 2 is in bounds"); 377 } 378 } 379