//! Helpers for undoing partial side effects when their larger operation fails.

use core::{fmt, mem, ops};

/// An RAII guard to roll back and undo something on (early) drop.
///
/// Dereferences to its inner `T`, and its undo function is given that `T` on
/// drop.
///
/// When all of the changes that need to happen together have happened, you can
/// call `Undo::commit` to disable the guard and commit the associated side
/// effects.
///
/// # Example
///
/// ```
/// use std::cell::Cell;
/// use wasmtime_internal_core::{error::Result, undo::Undo};
///
/// /// Some big ball of state that must always be coherent.
/// pub struct Context {
///     // ...
/// }
///
/// impl Context {
///     /// Perform some incremental mutation to `self`, which might not leave
///     /// it in a valid state unless its whole batch of work is completed.
///     fn do_thing(&mut self, arg: u32) -> Result<()> {
///         # let _ = arg;
///         # todo!()
///         // ...
///     }
///
///     /// Undo the side effects of `self.do_thing(arg)` for when we need to
///     /// roll back mutations.
///     fn undo_thing(&mut self, arg: u32) {
///         # let _ = arg;
///         // ...
///     }
///
///     /// Call `self.do_thing(arg)` for each `arg` in `args`.
///     ///
///     /// However, if any `self.do_thing(arg)` call fails, make sure that
///     /// we roll back to the original state by calling `self.undo_thing(arg)`
///     /// for all the `self.do_thing(arg)` calls that already succeeded. This
///     /// way we never leave `self` in a state where things got half-done.
///     pub fn do_all_or_nothing(&mut self, args: &[u32]) -> Result<()> {
///         // Counter for our progress, so that we know how much work to undo
///         // upon failure.
///         let num_things_done = Cell::new(0);
///
///         // Wrap the `Context` in an `Undo` that rolls back our side effects if
///         // we early-exit this function via `?`-propagation or panic unwinding.
///         let mut ctx = Undo::new(self, |ctx| {
///             for arg in args.iter().take(num_things_done.get()) {
///                 ctx.undo_thing(*arg);
///             }
///         });
///
///         // Do each piece of work!
///         for arg in args {
///             // Note: if this call returns an error that is `?`-propagated or
///             // triggers unwinding by panicking, then the work performed thus
///             // far will be rolled back when `ctx` is dropped.
///             ctx.do_thing(*arg)?;
///
///             // Update how much work has been completed.
///             num_things_done.set(num_things_done.get() + 1);
///         }
///
///         // We completed all of the work, so commit the `Undo` guard and
///         // disable its cleanup function.
///         Undo::commit(ctx);
///
///         Ok(())
///     }
/// }
/// ```
#[must_use = "`Undo` implicitly runs its undo function on drop; use `Undo::commit(...)` \
              to disable"]
pub struct Undo<T, F>
where
    F: FnOnce(T),
{
    // Both fields are `ManuallyDrop` so that `Drop::drop` and `Undo::commit`
    // can move them out by value exactly once.
    inner: mem::ManuallyDrop<T>,
    undo: mem::ManuallyDrop<F>,
}

impl<T, F> Drop for Undo<T, F>
where
    F: FnOnce(T),
{
    /// Runs the undo function, handing it ownership of the inner value.
    fn drop(&mut self) {
        // SAFETY: this is `Drop::drop`, which runs at most once, and neither
        // field is touched again afterwards. `Undo::commit` wraps the guard in
        // `ManuallyDrop`, so this impl never runs for a committed guard and
        // the fields are never taken twice.
        let inner = unsafe { mem::ManuallyDrop::take(&mut self.inner) };
        // SAFETY: same reasoning as above, for the `undo` field.
        let undo = unsafe { mem::ManuallyDrop::take(&mut self.undo) };
        undo(inner);
    }
}

impl<T, F> fmt::Debug for Undo<T, F>
where
    F: FnOnce(T),
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The undo closure is opaque, so only a placeholder is printed for it.
        f.debug_struct("Undo")
            .field("inner", &self.inner)
            .field("undo", &"..")
            .finish()
    }
}

impl<T, F> ops::Deref for Undo<T, F>
where
    F: FnOnce(T),
{
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl<T, F> ops::DerefMut for Undo<T, F>
where
    F: FnOnce(T),
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}

impl<T, F> Undo<T, F>
where
    F: FnOnce(T),
{
    /// Create a new `Undo` guard.
    ///
    /// This guard will wrap the given `inner` object and call `undo(inner)`
    /// when dropped, unless the guard is disabled via `Undo::commit`.
    pub fn new(inner: T, undo: F) -> Self {
        Self {
            inner: mem::ManuallyDrop::new(inner),
            undo: mem::ManuallyDrop::new(undo),
        }
    }

    /// Disable this `Undo` and return its inner value.
    ///
    /// This `Undo`'s cleanup function will never be called; it is dropped
    /// without being invoked, so anything it captured is still released.
    pub fn commit(guard: Self) -> T {
        // Suppress `Undo`'s `Drop` impl; the fields are moved out by hand below.
        let mut guard = mem::ManuallyDrop::new(guard);

        // Move `inner` out *before* dropping `undo`. If the closure's
        // destructor panics, `inner` is then an ordinary local and is dropped
        // during unwinding, instead of being leaked inside `guard`.
        //
        // SAFETY: `guard` is `ManuallyDrop`, so `Undo::drop` never runs for it
        // and `guard.inner` is not accessed again after this take.
        let inner = unsafe { mem::ManuallyDrop::take(&mut guard.inner) };

        // Make sure to drop `undo`, even though we aren't calling it, to
        // avoid leaking closed-over `Arc`s, for example.
        //
        // SAFETY: `guard.undo` is dropped exactly once here and never accessed
        // again, and `Undo::drop` cannot run for `guard`.
        unsafe { mem::ManuallyDrop::drop(&mut guard.undo) };

        inner
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::error::{Result, ensure};
    use core::{cell::Cell, cmp};
    use std::{panic, string::ToString};

    #[derive(Default)]
    struct Counter {
        value: u32,
        max_value_seen: u32,
    }

    impl Counter {
        fn inc(&mut self, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            f(self)?;
            self.value += 1;
            self.max_value_seen = cmp::max(self.max_value_seen, self.value);
            Ok(())
        }

        fn dec(&mut self) {
            self.value -= 1;
        }

        fn inc_n(&mut self, n: u32, mut f: impl FnMut(&Self) -> Result<()>) -> Result<()> {
            let i = Cell::new(0);

            let mut counter = Undo::new(self, |counter| {
                for _ in 0..i.get() {
                    counter.dec();
                }
            });

            for _ in 0..n {
                counter.inc(&mut f)?;
                i.set(i.get() + 1);
            }

            Undo::commit(counter);
            Ok(())
        }
    }

    #[test]
    fn error_propagation() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |c| {
            ensure!(c.value < 5, "uh oh");
            Ok(())
        });
        assert_eq!(result.unwrap_err().to_string(), "uh oh");
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn panic_unwind() {
        let mut counter = Counter::default();
        let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            counter.inc_n(10, |c| {
                assert!(c.value < 5);
                Ok(())
            })
        }));
        assert!(result.is_err());
        assert_eq!(counter.value, 0);
        assert_eq!(counter.max_value_seen, 5);
    }

    #[test]
    fn commit() {
        let mut counter = Counter::default();
        let result = counter.inc_n(10, |_| Ok(()));
        assert!(result.is_ok());
        assert_eq!(counter.value, 10);
        assert_eq!(counter.max_value_seen, 10);
    }

    #[test]
    fn commit_does_not_leak_inner_when_undo_drop_panics() {
        struct CountDrops<'a>(&'a Cell<u32>);
        impl<'a> Drop for CountDrops<'a> {
            fn drop(&mut self) {
                self.0.set(self.0.get() + 1);
            }
        }

        struct PanicOnDrop;
        impl Drop for PanicOnDrop {
            fn drop(&mut self) {
                panic!("boom");
            }
        }

        let drops = Cell::new(0);
        let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let bomb = PanicOnDrop;
            let guard = Undo::new(CountDrops(&drops), move |_| {
                let _keep = &bomb;
            });
            let _ = Undo::commit(guard);
        }));
        assert!(result.is_err());
        assert_eq!(drops.get(), 1);
    }
}