xref: /wasmtime-44.0.1/crates/core/src/undo.rs (revision fdd8ed9e)
1 //! Helpers for undoing partial side effects when their larger operation fails.
2 
3 use core::{fmt, mem, ops};
4 
/// An RAII guard to roll back and undo something on (early) drop.
///
/// Dereferences to its inner `T`, and its undo function is given the `T` on
/// drop.
///
/// When all of the changes that need to happen together have happened, call
/// [`Undo::commit`] to disable the guard and commit the associated side
/// effects.
///
/// # Example
///
/// ```
/// use std::cell::Cell;
/// use wasmtime_internal_core::{error::Result, undo::Undo};
///
/// /// Some big ball of state that must always be coherent.
/// pub struct Context {
///     // ...
/// }
///
/// impl Context {
///     /// Perform some incremental mutation to `self`, which might not leave
///     /// it in a valid state unless its whole batch of work is completed.
///     fn do_thing(&mut self, arg: u32) -> Result<()> {
/// #       let _ = arg;
/// #       todo!()
///         // ...
///     }
///
///     /// Undo the side effects of `self.do_thing(arg)` for when we need to
///     /// roll back mutations.
///     fn undo_thing(&mut self, arg: u32) {
/// #       let _ = arg;
///         // ...
///     }
///
///     /// Call `self.do_thing(arg)` for each `arg` in `args`.
///     ///
///     /// However, if any `self.do_thing(arg)` call fails, make sure that
///     /// we roll back to the original state by calling `self.undo_thing(arg)`
///     /// for all the `self.do_thing(arg)` calls that already succeeded. This
///     /// way we never leave `self` in a state where things got half-done.
///     pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> {
///         // Counter for our progress, so that we know how much work to undo
///         // upon failure.
///         let num_things_done = Cell::new(0);
///
///         // Wrap the `Context` in an `Undo` that rolls back our side effects if
///         // we early-exit this function via `?`-propagation or panic unwinding.
///         let mut ctx = Undo::new(self, |ctx| {
///             for arg in args.iter().take(num_things_done.get()) {
///                 ctx.undo_thing(*arg);
///             }
///         });
///
///         // Do each piece of work!
///         for arg in args {
///             // Note: if this call returns an error that is `?`-propagated or
///             // triggers unwinding by panicking, then the work performed thus
///             // far will be rolled back when `ctx` is dropped.
///             ctx.do_thing(*arg)?;
///
///             // Update how much work has been completed.
///             num_things_done.set(num_things_done.get() + 1);
///         }
///
///         // We completed all of the work, so commit the `Undo` guard and
///         // disable its cleanup function.
///         Undo::commit(ctx);
///
///         Ok(())
///     }
/// }
/// ```
#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \
              to disable"]
pub struct Undo<T, F>
where
    F: FnOnce(T),
{
    /// The guarded value. It is moved out exactly once: by `Drop` (to be
    /// handed to `undo`) or by `Undo::commit` (to be returned to the caller).
    inner: mem::ManuallyDrop<T>,
    /// The rollback function. It is either called by `Drop`, or dropped
    /// without being called by `Undo::commit`.
    undo: mem::ManuallyDrop<F>,
}
88 
89 impl<T, F> Drop for Undo<T, F>
90 where
91     F: FnOnce(T),
92 {
drop(&mut self)93     fn drop(&mut self) {
94         // Safety: These `ManuallyDrop` fields will not be used again.
95         let inner = unsafe { mem::ManuallyDrop::take(&mut self.inner) };
96         let undo = unsafe { mem::ManuallyDrop::take(&mut self.undo) };
97         undo(inner);
98     }
99 }
100 
101 impl<T, F> fmt::Debug for Undo<T, F>
102 where
103     F: FnOnce(T),
104     T: fmt::Debug,
105 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result106     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107         f.debug_struct("Undo")
108             .field("inner", &self.inner)
109             .field("undo", &"..")
110             .finish()
111     }
112 }
113 
114 impl<T, F> ops::Deref for Undo<T, F>
115 where
116     F: FnOnce(T),
117 {
118     type Target = T;
119 
deref(&self) -> &Self::Target120     fn deref(&self) -> &Self::Target {
121         &self.inner
122     }
123 }
124 
125 impl<T, F> ops::DerefMut for Undo<T, F>
126 where
127     F: FnOnce(T),
128 {
deref_mut(&mut self) -> &mut Self::Target129     fn deref_mut(&mut self) -> &mut Self::Target {
130         &mut self.inner
131     }
132 }
133 
134 impl<T, F> Undo<T, F>
135 where
136     F: FnOnce(T),
137 {
138     /// Create a new `Undo` guard.
139     ///
140     /// This guard will wrap the given `inner` object and call `undo(inner)`
141     /// when dropped, unless the guard is disabled via `Undo::commit`.
new(inner: T, undo: F) -> Self142     pub fn new(inner: T, undo: F) -> Self {
143         Self {
144             inner: mem::ManuallyDrop::new(inner),
145             undo: mem::ManuallyDrop::new(undo),
146         }
147     }
148 
149     /// Disable this `Undo` and return its inner value.
150     ///
151     /// This `Undo`'s cleanup function will never be called.
commit(guard: Self) -> T152     pub fn commit(guard: Self) -> T {
153         let mut guard = mem::ManuallyDrop::new(guard);
154 
155         // Safety: These `ManuallyDrop` fields will not be used again.
156         unsafe {
157             // Make sure to drop `undo`, even though we aren't calling it, to
158             // avoid leaking closed-over `Arc`s, for example.
159             mem::ManuallyDrop::drop(&mut guard.undo);
160 
161             mem::ManuallyDrop::take(&mut guard.inner)
162         }
163     }
164 }
165 
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::error::{Result, ensure};
    use core::{cell::Cell, cmp};
    use std::{panic, string::ToString};

    /// A counter whose increments can be rolled back, tracking its high-water mark.
    #[derive(Default)]
    struct Counter {
        value: u32,
        max_value_seen: u32,
    }

    impl Counter {
        /// Run the fallible check `f`, then increment and record the maximum.
        fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            f(self)?;
            self.value += 1;
            self.max_value_seen = cmp::max(self.max_value_seen, self.value);
            Ok(())
        }

        /// Undo a single successful `inc`.
        fn dec(&mut self) {
            self.value -= 1;
        }

        /// Increment `n` times, rolling back all completed increments if any
        /// `inc` fails or panics.
        fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            let i = Cell::new(0);

            let mut counter = Undo::new(self, |counter| {
                for _ in 0..i.get() {
                    counter.dec();
                }
            });

            for _ in 0..n {
                counter.inc(&mut f)?;
                i.set(i.get() + 1);
            }

            Undo::commit(counter);
            Ok(())
        }
    }

    #[test]
    fn error_propagation() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |c| {
            ensure!(c.value < 5, "uh oh");
            Ok(())
        });
        assert_eq!(result.unwrap_err().to_string(), "uh oh");
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn panic_unwind() {
        let mut counter = Counter::default();
        let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            counter.inc_n(10, |c| {
                assert!(c.value < 5);
                Ok(())
            })
        }));
        assert!(result.is_err());
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn commit() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |_| Ok(()));
        assert!(result.is_ok());
        assert_eq!(counter.value, 10);
        assert_eq!(counter.max_value_seen, 10);
    }

    /// Sets the referenced flag when dropped, to observe whether a value was dropped.
    struct SetOnDrop<'a>(&'a Cell<bool>);

    impl Drop for SetOnDrop<'_> {
        fn drop(&mut self) {
            self.0.set(true);
        }
    }

    /// `Undo::commit` must drop the undo closure (releasing its captures)
    /// without ever calling it.
    #[test]
    fn commit_drops_undo_without_calling_it() {
        let undo_called = Cell::new(false);
        let closure_dropped = Cell::new(false);
        let undo_called_ref = &undo_called;
        let captured = SetOnDrop(&closure_dropped);

        let guard = Undo::new(42_u32, move |_| {
            let _captured = &captured;
            undo_called_ref.set(true);
        });
        assert!(!closure_dropped.get());

        assert_eq!(Undo::commit(guard), 42);
        assert!(!undo_called.get());
        assert!(closure_dropped.get());
    }
}
245