1 use std::mem::{self, ManuallyDrop}; 2 use std::pin::Pin; 3 use std::sync::{Arc, Mutex}; 4 use std::task::{Context, Poll, Waker}; 5 6 /// Handle to a task which may be used to join on the result of executing it. 7 /// 8 /// This represents a handle to a running task which can be cancelled with 9 /// [`JoinHandle::abort`]. The final result and drop of the task can be 10 /// determined by `await`-ing this handle. 11 /// 12 /// Note that dropping this handle does not affect the running task it's 13 /// connected to. A manual invocation of [`JoinHandle::abort`] is required to 14 /// affect the task. 15 pub struct JoinHandle { 16 state: Arc<Mutex<JoinState>>, 17 } 18 19 enum JoinState { 20 /// The task this is connected to is still running and has not completed or 21 /// been dropped. 22 Running { 23 /// The waker that the running task has registered which is signaled 24 /// upon abort. 25 waiting_for_abort_signal: Option<Waker>, 26 27 /// The waker that the `JoinHandle` has registered to await 28 /// destruction of the running task itself. 29 waiting_for_abort_to_complete: Option<Waker>, 30 }, 31 32 /// An abort as been requested through an `JoinHandle`. The task specified 33 /// here is used for `Future for JoinHandle`. 34 AbortRequested { 35 waiting_for_abort_to_complete: Option<Waker>, 36 }, 37 38 /// The running task has completed, so no need to abort it and nothing else 39 /// needs to wait. 40 Complete, 41 } 42 43 impl JoinHandle { 44 /// Abort the task. 45 /// 46 /// This flags the connected task should abort in the near future, but note 47 /// that if this is called while the future is being polled then that call 48 /// will still complete. 49 /// 50 /// Note that this `JoinHandle` is itself a `Future` and can be used to 51 /// await the result and destruction of the task that this is associated 52 /// with. abort(&self)53 pub fn abort(&self) { 54 let mut state = self.state.lock().unwrap(); 55 56 match &mut *state { 57 // If this task is still running, then fall through to below to 58 // transition it into the `AbortRequested` state. If present the 59 // waker for the running task is notified to indicate that an abort 60 // signal has been received. 61 JoinState::Running { 62 waiting_for_abort_signal, 63 waiting_for_abort_to_complete, 64 } => { 65 if let Some(task) = waiting_for_abort_signal.take() { 66 task.wake(); 67 } 68 69 *state = JoinState::AbortRequested { 70 waiting_for_abort_to_complete: waiting_for_abort_to_complete.take(), 71 }; 72 } 73 74 // If this task has already been aborted or has completed, nothing 75 // is left to do. 76 JoinState::AbortRequested { .. } | JoinState::Complete => {} 77 } 78 } 79 80 /// Wraps the `future` provided in a new future which is "abortable" where 81 /// if the returned `JoinHandle` is flagged then the future will resolve 82 /// ASAP with `None` and drop the provided `future`. run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>) where F: Future,83 pub(crate) fn run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>) 84 where 85 F: Future, 86 { 87 let handle = JoinHandle { 88 state: Arc::new(Mutex::new(JoinState::Running { 89 waiting_for_abort_signal: None, 90 waiting_for_abort_to_complete: None, 91 })), 92 }; 93 let future = JoinHandleFuture { 94 future: ManuallyDrop::new(future), 95 state: handle.state.clone(), 96 }; 97 (handle, future) 98 } 99 } 100 101 impl Future for JoinHandle { 102 type Output = (); 103 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>104 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 105 let mut state = self.state.lock().unwrap(); 106 match &mut *state { 107 // If this task is running or still only has requested an abort, 108 // wait further for the task to get dropped. 109 JoinState::Running { 110 waiting_for_abort_to_complete, 111 .. 112 } 113 | JoinState::AbortRequested { 114 waiting_for_abort_to_complete, 115 } => { 116 *waiting_for_abort_to_complete = Some(cx.waker().clone()); 117 Poll::Pending 118 } 119 120 // The task is dropped, done! 121 JoinState::Complete => Poll::Ready(()), 122 } 123 } 124 } 125 126 struct JoinHandleFuture<F> { 127 future: ManuallyDrop<F>, 128 state: Arc<Mutex<JoinState>>, 129 } 130 131 impl<F> Future for JoinHandleFuture<F> 132 where 133 F: Future, 134 { 135 type Output = Option<F::Output>; 136 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>137 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 138 // SAFETY: this is a pin-projection from `Self` to the state and `Pin` 139 // of the internal future. This is the exclusive access of these fields 140 // apart from the destructor and should be safe. 141 let (state, future) = unsafe { 142 let me = self.get_unchecked_mut(); 143 (&me.state, Pin::new_unchecked(&mut *me.future)) 144 }; 145 146 // First, before polling the future, check to see if we've been 147 // aborted. If not register our task as awaiting such an abort. 148 { 149 let mut state = state.lock().unwrap(); 150 match &mut *state { 151 JoinState::Running { 152 waiting_for_abort_signal, 153 .. 154 } => { 155 *waiting_for_abort_signal = Some(cx.waker().clone()); 156 } 157 JoinState::AbortRequested { .. } | JoinState::Complete => { 158 return Poll::Ready(None); 159 } 160 } 161 } 162 163 future.poll(cx).map(Some) 164 } 165 } 166 167 impl<F> Drop for JoinHandleFuture<F> { drop(&mut self)168 fn drop(&mut self) { 169 // SAFETY: this is the exclusive owner of this future and it's safe to 170 // drop here during the owning destructor. 171 // 172 // Note that this explicitly happens before notifying the abort handle 173 // that the task completed so that when the notification goes through 174 // it's guaranteed that the future has been destroyed. 175 unsafe { 176 ManuallyDrop::drop(&mut self.future); 177 } 178 179 // After the future dropped see if there was a task awaiting its 180 // destruction. Simultaneously flag this state as complete. 181 let prev = mem::replace(&mut *self.state.lock().unwrap(), JoinState::Complete); 182 let task = match prev { 183 JoinState::Running { 184 waiting_for_abort_to_complete, 185 .. 186 } 187 | JoinState::AbortRequested { 188 waiting_for_abort_to_complete, 189 } => waiting_for_abort_to_complete, 190 JoinState::Complete => None, 191 }; 192 if let Some(task) = task { 193 task.wake(); 194 } 195 } 196 } 197 198 #[cfg(test)] 199 mod tests { 200 use super::JoinHandle; 201 use std::pin::{Pin, pin}; 202 use std::task::{Context, Poll, Waker}; 203 use tokio::sync::oneshot; 204 is_ready<F>(future: Pin<&mut F>) -> bool where F: Future,205 fn is_ready<F>(future: Pin<&mut F>) -> bool 206 where 207 F: Future, 208 { 209 match future.poll(&mut Context::from_waker(Waker::noop())) { 210 Poll::Ready(_) => true, 211 Poll::Pending => false, 212 } 213 } 214 215 #[tokio::test] abort_in_progress()216 async fn abort_in_progress() { 217 let (tx, rx) = oneshot::channel::<()>(); 218 let (mut handle, future) = JoinHandle::run(rx); 219 let mut handle = Pin::new(&mut handle); 220 { 221 let mut future = pin!(future); 222 assert!(!is_ready(future.as_mut())); 223 assert!(!is_ready(handle.as_mut())); 224 handle.abort(); 225 assert!(is_ready(future.as_mut())); 226 assert!(!is_ready(handle.as_mut())); 227 assert!(!tx.is_closed()); 228 } 229 assert!(is_ready(handle.as_mut())); 230 assert!(tx.is_closed()); 231 } 232 233 #[tokio::test] abort_complete()234 async fn abort_complete() { 235 let (tx, rx) = oneshot::channel::<()>(); 236 let (mut handle, future) = JoinHandle::run(rx); 237 let mut handle = Pin::new(&mut handle); 238 tx.send(()).unwrap(); 239 assert!(!is_ready(handle.as_mut())); 240 { 241 let mut future = pin!(future); 242 assert!(is_ready(future.as_mut())); 243 assert!(!is_ready(handle.as_mut())); 244 } 245 assert!(is_ready(handle.as_mut())); 246 handle.abort(); 247 assert!(is_ready(handle.as_mut())); 248 } 249 250 #[tokio::test] abort_dropped()251 async fn abort_dropped() { 252 let (tx, rx) = oneshot::channel::<()>(); 253 let (mut handle, future) = JoinHandle::run(rx); 254 let mut handle = Pin::new(&mut handle); 255 drop(future); 256 assert!(is_ready(handle.as_mut())); 257 handle.abort(); 258 assert!(is_ready(handle.as_mut())); 259 assert!(tx.is_closed()); 260 } 261 262 #[tokio::test] await_completion()263 async fn await_completion() { 264 let (tx, rx) = oneshot::channel::<()>(); 265 tx.send(()).unwrap(); 266 let (handle, future) = JoinHandle::run(rx); 267 let task = tokio::task::spawn(future); 268 handle.await; 269 task.await.unwrap(); 270 } 271 272 #[tokio::test] await_abort()273 async fn await_abort() { 274 let (tx, rx) = oneshot::channel::<()>(); 275 tx.send(()).unwrap(); 276 let (handle, future) = JoinHandle::run(rx); 277 handle.abort(); 278 let task = tokio::task::spawn(future); 279 handle.await; 280 task.await.unwrap(); 281 } 282 } 283