1 //! Simple string interning.
2 
3 use crate::{error::OutOfMemory, prelude::*};
4 use core::{fmt, mem, num::NonZeroU32};
5 
/// An interned string associated with a particular string in a `StringPool`.
///
/// Allows for `O(1)` equality tests, `O(1)` hashing, and `O(1)`
/// arbitrary-but-stable ordering: all of these operate on the atom's integer
/// index rather than on the string contents.
///
/// An `Atom` is only meaningful together with the pool that handed it out;
/// resolve it back to its string with `&pool[atom]` or `pool.get(atom)`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Atom {
    /// The position of this atom's string within its pool, stored *plus one*
    /// so that the value is never zero and fits in a `NonZeroU32`.
    index: NonZeroU32,
}
14 
/// A pool of interned strings.
///
/// Insert new strings with [`StringPool::insert`] to get an `Atom` that is
/// unique per string within the context of the associated pool.
///
/// Once you have interned a string into the pool and have its `Atom`, you can
/// get the interned string slice via `&pool[atom]` or `pool.get(atom)`.
///
/// In general, there are no correctness protections against indexing into a
/// different `StringPool` from the one that the `Atom` was allocated inside.
/// Doing so is memory safe but may panic or otherwise return incorrect
/// results.
#[derive(Default)]
pub struct StringPool {
    /// A map from each string in this pool (as an unsafe borrow from
    /// `self.strings`) to its `Atom`.
    ///
    /// The `'static` lifetime on the keys is a lie: they point into the heap
    /// data owned by `self.strings`, so this map must never outlive those
    /// strings. It is wrapped in `ManuallyDrop` so that the `Drop`
    /// implementation can drop it before `self.strings`.
    map: mem::ManuallyDrop<TryHashMap<&'static str, Atom>>,

    /// Strings in this pool. These must never be mutated or reallocated once
    /// inserted.
    ///
    /// The position of a string in this vector is the index of its `Atom`.
    strings: mem::ManuallyDrop<TryVec<Box<str>>>,
}
37 
38 impl Drop for StringPool {
39     fn drop(&mut self) {
40         // Ensure that `self.map` is dropped before `self.strings`, since
41         // `self.map` borrows from `self.strings`.
42         //
43         // Safety: Neither field will be used again.
44         unsafe {
45             mem::ManuallyDrop::drop(&mut self.map);
46             mem::ManuallyDrop::drop(&mut self.strings);
47         }
48     }
49 }
50 
51 impl fmt::Debug for StringPool {
52     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53         struct Strings<'a>(&'a StringPool);
54         impl fmt::Debug for Strings<'_> {
55             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56                 f.debug_map()
57                     .entries(
58                         self.0
59                             .strings
60                             .iter()
61                             .enumerate()
62                             .map(|(i, s)| (Atom::new(i), s)),
63                     )
64                     .finish()
65             }
66         }
67 
68         f.debug_struct("StringPool")
69             .field("strings", &Strings(self))
70             .finish()
71     }
72 }
73 
74 impl TryClone for StringPool {
75     fn try_clone(&self) -> Result<Self, OutOfMemory> {
76         Ok(StringPool {
77             map: self.map.try_clone()?,
78             strings: self.strings.try_clone()?,
79         })
80     }
81 }
82 
83 impl TryClone for Atom {
84     fn try_clone(&self) -> Result<Self, OutOfMemory> {
85         Ok(*self)
86     }
87 }
88 
89 impl core::ops::Index<Atom> for StringPool {
90     type Output = str;
91 
92     #[inline]
93     #[track_caller]
94     fn index(&self, atom: Atom) -> &Self::Output {
95         self.get(atom).unwrap()
96     }
97 }
98 
99 // For convenience, to avoid `*atom` noise at call sites.
100 impl core::ops::Index<&'_ Atom> for StringPool {
101     type Output = str;
102 
103     #[inline]
104     #[track_caller]
105     fn index(&self, atom: &Atom) -> &Self::Output {
106         self.get(*atom).unwrap()
107     }
108 }
109 
110 impl serde::ser::Serialize for StringPool {
111     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
112     where
113         S: serde::Serializer,
114     {
115         serde::ser::Serialize::serialize(&*self.strings, serializer)
116     }
117 }
118 
119 impl<'de> serde::de::Deserialize<'de> for StringPool {
120     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
121     where
122         D: serde::Deserializer<'de>,
123     {
124         struct Visitor;
125         impl<'de> serde::de::Visitor<'de> for Visitor {
126             type Value = StringPool;
127 
128             fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
129                 f.write_str("a `StringPool` sequence of strings")
130             }
131 
132             fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
133             where
134                 A: serde::de::SeqAccess<'de>,
135             {
136                 use serde::de::Error as _;
137 
138                 let mut pool = StringPool::new();
139 
140                 if let Some(len) = seq.size_hint() {
141                     pool.map.reserve(len).map_err(|oom| A::Error::custom(oom))?;
142                     pool.strings
143                         .reserve(len)
144                         .map_err(|oom| A::Error::custom(oom))?;
145                 }
146 
147                 while let Some(s) = seq.next_element::<TryString>()? {
148                     debug_assert_eq!(s.len(), s.capacity());
149                     let s = s.into_boxed_str().map_err(|oom| A::Error::custom(oom))?;
150                     if !pool.map.contains_key(&*s) {
151                         pool.insert_new_boxed_str(s)
152                             .map_err(|oom| A::Error::custom(oom))?;
153                     }
154                 }
155 
156                 Ok(pool)
157             }
158         }
159         deserializer.deserialize_seq(Visitor)
160     }
161 }
162 
163 impl StringPool {
164     /// Create a new, empty pool.
165     pub fn new() -> Self {
166         Self::default()
167     }
168 
169     /// Insert a new string into this pool.
170     pub fn insert(&mut self, s: &str) -> Result<Atom, OutOfMemory> {
171         if let Some(atom) = self.map.get(s) {
172             return Ok(*atom);
173         }
174 
175         self.map.reserve(1)?;
176         self.strings.reserve(1)?;
177 
178         let mut owned = TryString::new();
179         owned.reserve_exact(s.len())?;
180         owned.push_str(s).expect("reserved capacity");
181         let owned = owned
182             .into_boxed_str()
183             .expect("reserved exact capacity, so shouldn't need to realloc");
184 
185         self.insert_new_boxed_str(owned)
186     }
187 
188     fn insert_new_boxed_str(&mut self, owned: Box<str>) -> Result<Atom, OutOfMemory> {
189         debug_assert!(!self.map.contains_key(&*owned));
190 
191         let index = self.strings.len();
192         let atom = Atom::new(index);
193         self.strings.push(owned)?;
194 
195         // SAFETY: We never expose this borrow and never mutate or reallocate
196         // strings once inserted into the pool.
197         let s = unsafe { mem::transmute::<&str, &'static str>(&self.strings[index]) };
198 
199         let old = self.map.insert(s, atom)?;
200         debug_assert!(old.is_none());
201 
202         Ok(atom)
203     }
204 
205     /// Get the `Atom` for the given string, if it has already been inserted
206     /// into this pool.
207     pub fn get_atom(&self, s: &str) -> Option<Atom> {
208         self.map.get(s).copied()
209     }
210 
211     /// Does this pool contain the given `atom`?
212     #[inline]
213     pub fn contains(&self, atom: Atom) -> bool {
214         atom.index() < self.strings.len()
215     }
216 
217     /// Get the string associated with the given `atom`, if the pool contains
218     /// the atom.
219     #[inline]
220     pub fn get(&self, atom: Atom) -> Option<&str> {
221         if self.contains(atom) {
222             Some(&self.strings[atom.index()])
223         } else {
224             None
225         }
226     }
227 
228     /// Get the number of strings in this pool.
229     pub fn len(&self) -> usize {
230         self.strings.len()
231     }
232 }
233 
234 impl Default for Atom {
235     #[inline]
236     fn default() -> Self {
237         Self {
238             index: NonZeroU32::MAX,
239         }
240     }
241 }
242 
243 impl fmt::Debug for Atom {
244     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245         f.debug_struct("Atom")
246             .field("index", &self.index())
247             .finish()
248     }
249 }
250 
// Allow using `Atom` in `SecondaryMap`s.
//
// Both methods simply forward to the inherent `Atom` functions of the same
// name.
impl crate::EntityRef for Atom {
    fn new(index: usize) -> Self {
        // The path form `Atom::new` picks the inherent constructor, not this
        // trait method, so this does not recurse.
        Atom::new(index)
    }

    fn index(self) -> usize {
        // Likewise, `Atom::index` taking `&self` is the inherent accessor;
        // keep the explicit path-and-reference form rather than writing
        // `self.index()`.
        Atom::index(&self)
    }
}
261 
262 impl serde::ser::Serialize for Atom {
263     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
264     where
265         S: serde::Serializer,
266     {
267         serde::ser::Serialize::serialize(&self.index, serializer)
268     }
269 }
270 
271 impl<'de> serde::de::Deserialize<'de> for Atom {
272     fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
273     where
274         D: serde::Deserializer<'de>,
275     {
276         let index = serde::de::Deserialize::deserialize(deserializer)?;
277         Ok(Self { index })
278     }
279 }
280 
281 impl Atom {
282     fn new(index: usize) -> Self {
283         assert!(index < usize::try_from(u32::MAX).unwrap());
284         let index = u32::try_from(index).unwrap();
285         let index = NonZeroU32::new(index + 1).unwrap();
286         Self { index }
287     }
288 
289     /// Get this atom's index in its pool.
290     pub fn index(&self) -> usize {
291         let index = self.index.get() - 1;
292         usize::try_from(index).unwrap()
293     }
294 }
295 
#[cfg(test)]
mod tests {
    use super::*;

    /// Interning, de-duplication, lookup, and what happens when an atom is
    /// used with a pool other than the one that created it.
    #[test]
    fn basic() -> Result<()> {
        let mut pool = StringPool::new();

        // A first insertion hands out an atom that resolves to the string.
        let a = pool.insert("a")?;
        assert_eq!(pool.get_atom("a"), Some(a));
        assert_eq!(&pool[a], "a");

        // Re-inserting the same string yields the same atom.
        let a2 = pool.insert("a")?;
        assert_eq!(&pool[a2], "a");
        assert_eq!(a, a2);

        // A different string gets a different atom.
        let b = pool.insert("b")?;
        assert_ne!(a, b);
        assert_eq!(pool.get_atom("b"), Some(b));
        assert_eq!(&pool[b], "b");

        // Strings that were never interned have no atom.
        assert!(pool.get_atom("zzz").is_none());

        // Atoms are only unique per pool: the first atom of a second pool
        // compares equal to the first atom of the first pool.
        let mut other = StringPool::new();
        let c = other.insert("c")?;
        assert_eq!(&other[c], "c");
        assert_eq!(a, c);
        assert_eq!(&other[a], "c");

        // Out-of-bounds atoms from another pool are rejected.
        assert!(other.get(b).is_none());
        assert!(!other.contains(b));

        Ok(())
    }

    /// Intern many strings twice over; the second pass must hit the
    /// already-interned entries and return the same atoms.
    #[test]
    fn stress() -> Result<()> {
        let mut pool = StringPool::new();

        // Keep the workload small under Miri.
        let count = if cfg!(miri) { 100 } else { 10_000 };

        for _round in 0..2 {
            let interned: TryVec<_> = (0..count)
                .map(|i| {
                    let text = i.to_string();
                    pool.insert(&text)
                })
                .try_collect()?;

            for atom in interned {
                assert!(pool.contains(atom));
                // The i-th distinct string is the decimal rendering of `i`.
                assert_eq!(&pool[atom], atom.index().to_string());
            }
        }

        Ok(())
    }

    /// A pool and its atoms survive a serialize/deserialize round trip.
    #[test]
    fn roundtrip_serialize_deserialize() -> Result<()> {
        let mut pool = StringPool::new();
        let a = pool.insert("a")?;
        let b = pool.insert("b")?;
        let c = pool.insert("c")?;

        let encoded = postcard::to_allocvec(&(pool, a, b, c))?;
        let decoded: (StringPool, Atom, Atom, Atom) = postcard::from_bytes(&encoded)?;
        let (pool, a2, b2, c2) = decoded;

        // Both the original atoms and the deserialized ones must resolve
        // correctly against the deserialized pool.
        let expectations = [
            (a, "a"),
            (b, "b"),
            (c, "c"),
            (a2, "a"),
            (b2, "b"),
            (c2, "c"),
        ];
        for (atom, expected) in expectations {
            assert_eq!(&pool[atom], expected);
        }

        Ok(())
    }
}
369