1 //! Simple string interning.
2 
3 use crate::{error::OutOfMemory, prelude::*};
4 use core::{fmt, mem, num::NonZeroU32};
5 
/// An interned string associated with a particular string in a `StringPool`.
///
/// Allows for $O(1)$ equality tests, $O(1)$ hashing, and $O(1)$
/// arbitrary-but-stable ordering.
///
/// All of the derived comparisons and the derived hash operate on the stored
/// integer only; the string contents are never consulted. As a consequence,
/// two atoms obtained from *different* pools compare equal whenever they have
/// the same index, regardless of which strings they name.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Atom {
    /// The one-based position of the string inside its pool, i.e. the
    /// zero-based index plus one (see `Atom::new` and `Atom::index`).
    index: NonZeroU32,
}
14 
/// A pool of interned strings.
///
/// Insert new strings with [`StringPool::insert`] to get an `Atom` that is
/// unique per string within the context of the associated pool.
///
/// Once you have interned a string into the pool and have its `Atom`, you can
/// get the interned string slice via `&pool[atom]` or `pool.get(atom)`.
///
/// In general, there are no correctness protections against indexing into a
/// different `StringPool` from the one that the `Atom` was allocated inside.
/// Doing so is memory safe but may panic or otherwise return incorrect
/// results.
#[derive(Default)]
pub struct StringPool {
    /// A map from each string in this pool (as an unsafe borrow from
    /// `self.strings`) to its `Atom`.
    ///
    /// The `'static` lifetime on the keys is a lie: the keys really borrow
    /// from `self.strings`. The field is wrapped in `ManuallyDrop` so that the
    /// `Drop` implementation can destroy it before `self.strings`.
    map: mem::ManuallyDrop<TryHashMap<&'static str, Atom>>,

    /// Strings in this pool. These must never be mutated or reallocated once
    /// inserted.
    ///
    /// The position of a string in this list is the index of its `Atom`.
    /// Wrapped in `ManuallyDrop` so that the `Drop` implementation can destroy
    /// it only after `self.map`, which borrows from it.
    strings: mem::ManuallyDrop<TryVec<Box<str>>>,
}
37 
38 impl Drop for StringPool {
drop(&mut self)39     fn drop(&mut self) {
40         // Ensure that `self.map` is dropped before `self.strings`, since
41         // `self.map` borrows from `self.strings`.
42         //
43         // Safety: Neither field will be used again.
44         unsafe {
45             mem::ManuallyDrop::drop(&mut self.map);
46             mem::ManuallyDrop::drop(&mut self.strings);
47         }
48     }
49 }
50 
51 impl fmt::Debug for StringPool {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result52     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53         struct Strings<'a>(&'a StringPool);
54         impl fmt::Debug for Strings<'_> {
55             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56                 f.debug_map()
57                     .entries(
58                         self.0
59                             .strings
60                             .iter()
61                             .enumerate()
62                             .map(|(i, s)| (Atom::new(i), s)),
63                     )
64                     .finish()
65             }
66         }
67 
68         f.debug_struct("StringPool")
69             .field("strings", &Strings(self))
70             .finish()
71     }
72 }
73 
74 impl TryClone for StringPool {
try_clone(&self) -> Result<Self, OutOfMemory>75     fn try_clone(&self) -> Result<Self, OutOfMemory> {
76         let mut new_pool = StringPool::new();
77         // Re-intern strings in index order so that each Atom value is
78         // identical in the clone — callers that hold Atoms from the original
79         // can use them interchangeably with the clone.
80         //
81         // Directly cloning `self.map` would copy &'static str keys that point
82         // into the *original* pool's `strings` allocation. Those pointers
83         // become dangling once the original is dropped, leading to UB on any
84         // subsequent lookup. Re-interning ensures the cloned map's keys point
85         // into the clone's own `strings`.
86         for s in self.strings.iter() {
87             new_pool.insert(s)?;
88         }
89         Ok(new_pool)
90     }
91 }
92 
93 impl TryClone for Atom {
try_clone(&self) -> Result<Self, OutOfMemory>94     fn try_clone(&self) -> Result<Self, OutOfMemory> {
95         Ok(*self)
96     }
97 }
98 
99 impl core::ops::Index<Atom> for StringPool {
100     type Output = str;
101 
102     #[inline]
103     #[track_caller]
index(&self, atom: Atom) -> &Self::Output104     fn index(&self, atom: Atom) -> &Self::Output {
105         self.get(atom).unwrap()
106     }
107 }
108 
109 // For convenience, to avoid `*atom` noise at call sites.
110 impl core::ops::Index<&'_ Atom> for StringPool {
111     type Output = str;
112 
113     #[inline]
114     #[track_caller]
index(&self, atom: &Atom) -> &Self::Output115     fn index(&self, atom: &Atom) -> &Self::Output {
116         self.get(*atom).unwrap()
117     }
118 }
119 
120 impl serde::ser::Serialize for StringPool {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer,121     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
122     where
123         S: serde::Serializer,
124     {
125         serde::ser::Serialize::serialize(&*self.strings, serializer)
126     }
127 }
128 
129 impl<'de> serde::de::Deserialize<'de> for StringPool {
deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: serde::Deserializer<'de>,130     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
131     where
132         D: serde::Deserializer<'de>,
133     {
134         struct Visitor;
135         impl<'de> serde::de::Visitor<'de> for Visitor {
136             type Value = StringPool;
137 
138             fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
139                 f.write_str("a `StringPool` sequence of strings")
140             }
141 
142             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
143             where
144                 A: serde::de::SeqAccess<'de>,
145             {
146                 use serde::de::Error as _;
147 
148                 let mut pool = StringPool::new();
149 
150                 if let Some(len) = seq.size_hint() {
151                     pool.map.reserve(len).map_err(|oom| A::Error::custom(oom))?;
152                     pool.strings
153                         .reserve(len)
154                         .map_err(|oom| A::Error::custom(oom))?;
155                 }
156 
157                 while let Some(s) = seq.next_element::<TryString>()? {
158                     debug_assert_eq!(s.len(), s.capacity());
159                     let s = s.into_boxed_str().map_err(|oom| A::Error::custom(oom))?;
160                     if !pool.map.contains_key(&*s) {
161                         pool.insert_new_boxed_str(s)
162                             .map_err(|oom| A::Error::custom(oom))?;
163                     }
164                 }
165 
166                 Ok(pool)
167             }
168         }
169         deserializer.deserialize_seq(Visitor)
170     }
171 }
172 
173 impl StringPool {
174     /// Create a new, empty pool.
new() -> Self175     pub fn new() -> Self {
176         Self::default()
177     }
178 
179     /// Insert a new string into this pool.
insert(&mut self, s: &str) -> Result<Atom, OutOfMemory>180     pub fn insert(&mut self, s: &str) -> Result<Atom, OutOfMemory> {
181         if let Some(atom) = self.map.get(s) {
182             return Ok(*atom);
183         }
184 
185         self.map.reserve(1)?;
186         self.strings.reserve(1)?;
187 
188         let mut owned = TryString::new();
189         owned.reserve_exact(s.len())?;
190         owned.push_str(s).expect("reserved capacity");
191         let owned = owned
192             .into_boxed_str()
193             .expect("reserved exact capacity, so shouldn't need to realloc");
194 
195         self.insert_new_boxed_str(owned)
196     }
197 
insert_new_boxed_str(&mut self, owned: Box<str>) -> Result<Atom, OutOfMemory>198     fn insert_new_boxed_str(&mut self, owned: Box<str>) -> Result<Atom, OutOfMemory> {
199         debug_assert!(!self.map.contains_key(&*owned));
200 
201         let index = self.strings.len();
202         let atom = Atom::new(index);
203         self.strings.push(owned)?;
204 
205         // SAFETY: We never expose this borrow and never mutate or reallocate
206         // strings once inserted into the pool.
207         let s = unsafe { mem::transmute::<&str, &'static str>(&self.strings[index]) };
208 
209         let old = self.map.insert(s, atom)?;
210         debug_assert!(old.is_none());
211 
212         Ok(atom)
213     }
214 
215     /// Get the `Atom` for the given string, if it has already been inserted
216     /// into this pool.
get_atom(&self, s: &str) -> Option<Atom>217     pub fn get_atom(&self, s: &str) -> Option<Atom> {
218         self.map.get(s).copied()
219     }
220 
221     /// Does this pool contain the given `atom`?
222     #[inline]
contains(&self, atom: Atom) -> bool223     pub fn contains(&self, atom: Atom) -> bool {
224         atom.index() < self.strings.len()
225     }
226 
227     /// Get the string associated with the given `atom`, if the pool contains
228     /// the atom.
229     #[inline]
get(&self, atom: Atom) -> Option<&str>230     pub fn get(&self, atom: Atom) -> Option<&str> {
231         if self.contains(atom) {
232             Some(&self.strings[atom.index()])
233         } else {
234             None
235         }
236     }
237 
238     /// Get the number of strings in this pool.
len(&self) -> usize239     pub fn len(&self) -> usize {
240         self.strings.len()
241     }
242 }
243 
244 impl Default for Atom {
245     #[inline]
default() -> Self246     fn default() -> Self {
247         Self {
248             index: NonZeroU32::MAX,
249         }
250     }
251 }
252 
253 impl fmt::Debug for Atom {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result254     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255         f.debug_struct("Atom")
256             .field("index", &self.index())
257             .finish()
258     }
259 }
260 
261 // Allow using `Atom` in `SecondaryMap`s.
262 impl crate::EntityRef for Atom {
new(index: usize) -> Self263     fn new(index: usize) -> Self {
264         Atom::new(index)
265     }
266 
index(self) -> usize267     fn index(self) -> usize {
268         Atom::index(&self)
269     }
270 }
271 
272 impl serde::ser::Serialize for Atom {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer,273     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
274     where
275         S: serde::Serializer,
276     {
277         serde::ser::Serialize::serialize(&self.index, serializer)
278     }
279 }
280 
281 impl<'de> serde::de::Deserialize<'de> for Atom {
deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error> where D: serde::Deserializer<'de>,282     fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
283     where
284         D: serde::Deserializer<'de>,
285     {
286         let index = serde::de::Deserialize::deserialize(deserializer)?;
287         Ok(Self { index })
288     }
289 }
290 
291 impl Atom {
new(index: usize) -> Self292     fn new(index: usize) -> Self {
293         assert!(index < usize::try_from(u32::MAX).unwrap());
294         let index = u32::try_from(index).unwrap();
295         let index = NonZeroU32::new(index + 1).unwrap();
296         Self { index }
297     }
298 
299     /// Get this atom's index in its pool.
index(&self) -> usize300     pub fn index(&self) -> usize {
301         let index = self.index.get() - 1;
302         usize::try_from(index).unwrap()
303     }
304 }
305 
#[cfg(test)]
mod tests {
    use super::*;

    /// Interning is idempotent within a pool, and an atom is only meaningful
    /// relative to the pool that produced it.
    #[test]
    fn basic() -> Result<()> {
        let mut first_pool = StringPool::new();

        let atom_a = first_pool.insert("a")?;
        assert_eq!(&first_pool[atom_a], "a");
        assert_eq!(first_pool.get_atom("a"), Some(atom_a));

        // Interning the same string again hands back the same atom.
        let atom_a_again = first_pool.insert("a")?;
        assert_eq!(atom_a, atom_a_again);
        assert_eq!(&first_pool[atom_a_again], "a");

        let atom_b = first_pool.insert("b")?;
        assert_eq!(&first_pool[atom_b], "b");
        assert_ne!(atom_a, atom_b);
        assert_eq!(first_pool.get_atom("b"), Some(atom_b));

        assert!(first_pool.get_atom("zzz").is_none());

        // Atoms are plain indices: the first atom of another pool equals the
        // first atom of this one, and indexes whatever that pool holds there.
        let mut second_pool = StringPool::new();
        let atom_c = second_pool.insert("c")?;
        assert_eq!(&second_pool[atom_c], "c");
        assert_eq!(atom_a, atom_c);
        assert_eq!(&second_pool[atom_a], "c");
        assert!(!second_pool.contains(atom_b));
        assert!(second_pool.get(atom_b).is_none());

        Ok(())
    }

    /// Intern many distinct strings, twice over, and check every atom still
    /// resolves to its string.
    #[test]
    fn stress() -> Result<()> {
        let mut pool = StringPool::new();

        let count = if cfg!(miri) { 100 } else { 10_000 };

        // The second round re-inserts strings that are already interned.
        for _round in 0..2 {
            let atoms: TryVec<_> = (0..count)
                .map(|i| pool.insert(&i.to_string()))
                .try_collect()?;

            for atom in atoms {
                assert!(pool.contains(atom));
                assert_eq!(&pool[atom], atom.index().to_string());
            }
        }

        Ok(())
    }

    /// A pool and its atoms survive a serialize/deserialize round trip, and
    /// both the original and the deserialized atoms index the restored pool.
    #[test]
    fn roundtrip_serialize_deserialize() -> Result<()> {
        let mut original = StringPool::new();
        let a = original.insert("a")?;
        let b = original.insert("b")?;
        let c = original.insert("c")?;

        let bytes = postcard::to_allocvec(&(original, a, b, c))?;
        let restored: (StringPool, Atom, Atom, Atom) = postcard::from_bytes(&bytes)?;
        let (pool, a2, b2, c2) = restored;

        let expectations = [
            (a, "a"),
            (b, "b"),
            (c, "c"),
            (a2, "a"),
            (b2, "b"),
            (c2, "c"),
        ];
        for &(atom, expected) in &expectations {
            assert_eq!(&pool[atom], expected);
        }

        Ok(())
    }
}
379