1 use std::mem::{self, ManuallyDrop};
2 use std::pin::Pin;
3 use std::sync::{Arc, Mutex};
4 use std::task::{Context, Poll, Waker};
5 
6 /// Handle to a task which may be used to join on the result of executing it.
7 ///
8 /// This represents a handle to a running task which can be cancelled with
9 /// [`JoinHandle::abort`]. The final result and drop of the task can be
10 /// determined by `await`-ing this handle.
11 ///
12 /// Note that dropping this handle does not affect the running task it's
13 /// connected to. A manual invocation of [`JoinHandle::abort`] is required to
14 /// affect the task.
15 pub struct JoinHandle {
16     state: Arc<Mutex<JoinState>>,
17 }
18 
/// State shared between a `JoinHandle` and the future it controls.
enum JoinState {
    /// The task this is connected to is still running: no abort has been
    /// requested and its future has not been dropped.
    Running {
        /// The waker that the running task has registered, which is signaled
        /// upon abort.
        waiting_for_abort_signal: Option<Waker>,

        /// The waker that the `JoinHandle` has registered to await
        /// destruction of the running task itself.
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// An abort has been requested through a `JoinHandle`, but the task's
    /// future has not been dropped yet. The waker stored here is the one
    /// registered by `Future for JoinHandle`, to be woken once the future is
    /// dropped.
    AbortRequested {
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// The task's future has been dropped (whether or not it completed or was
    /// aborted first), so there is nothing left to abort and nothing else to
    /// wait for.
    Complete,
}
42 
43 impl JoinHandle {
44     /// Abort the task.
45     ///
46     /// This flags the connected task should abort in the near future, but note
47     /// that if this is called while the future is being polled then that call
48     /// will still complete.
49     ///
50     /// Note that this `JoinHandle` is itself a `Future` and can be used to
51     /// await the result and destruction of the task that this is associated
52     /// with.
abort(&self)53     pub fn abort(&self) {
54         let mut state = self.state.lock().unwrap();
55 
56         match &mut *state {
57             // If this task is still running, then fall through to below to
58             // transition it into the `AbortRequested` state. If present the
59             // waker for the running task is notified to indicate that an abort
60             // signal has been received.
61             JoinState::Running {
62                 waiting_for_abort_signal,
63                 waiting_for_abort_to_complete,
64             } => {
65                 if let Some(task) = waiting_for_abort_signal.take() {
66                     task.wake();
67                 }
68 
69                 *state = JoinState::AbortRequested {
70                     waiting_for_abort_to_complete: waiting_for_abort_to_complete.take(),
71                 };
72             }
73 
74             // If this task has already been aborted or has completed, nothing
75             // is left to do.
76             JoinState::AbortRequested { .. } | JoinState::Complete => {}
77         }
78     }
79 
80     /// Wraps the `future` provided in a new future which is "abortable" where
81     /// if the returned `JoinHandle` is flagged then the future will resolve
82     /// ASAP with `None` and drop the provided `future`.
run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>) where F: Future,83     pub(crate) fn run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>)
84     where
85         F: Future,
86     {
87         let handle = JoinHandle {
88             state: Arc::new(Mutex::new(JoinState::Running {
89                 waiting_for_abort_signal: None,
90                 waiting_for_abort_to_complete: None,
91             })),
92         };
93         let future = JoinHandleFuture {
94             future: ManuallyDrop::new(future),
95             state: handle.state.clone(),
96         };
97         (handle, future)
98     }
99 }
100 
101 impl Future for JoinHandle {
102     type Output = ();
103 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>104     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105         let mut state = self.state.lock().unwrap();
106         match &mut *state {
107             // If this task is running or still only has requested an abort,
108             // wait further for the task to get dropped.
109             JoinState::Running {
110                 waiting_for_abort_to_complete,
111                 ..
112             }
113             | JoinState::AbortRequested {
114                 waiting_for_abort_to_complete,
115             } => {
116                 *waiting_for_abort_to_complete = Some(cx.waker().clone());
117                 Poll::Pending
118             }
119 
120             // The task is dropped, done!
121             JoinState::Complete => Poll::Ready(()),
122         }
123     }
124 }
125 
126 struct JoinHandleFuture<F> {
127     future: ManuallyDrop<F>,
128     state: Arc<Mutex<JoinState>>,
129 }
130 
131 impl<F> Future for JoinHandleFuture<F>
132 where
133     F: Future,
134 {
135     type Output = Option<F::Output>;
136 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>137     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138         // SAFETY: this is a pin-projection from `Self` to the state and `Pin`
139         // of the internal future. This is the exclusive access of these fields
140         // apart from the destructor and should be safe.
141         let (state, future) = unsafe {
142             let me = self.get_unchecked_mut();
143             (&me.state, Pin::new_unchecked(&mut *me.future))
144         };
145 
146         // First, before polling the future, check to see if we've been
147         // aborted. If not register our task as awaiting such an abort.
148         {
149             let mut state = state.lock().unwrap();
150             match &mut *state {
151                 JoinState::Running {
152                     waiting_for_abort_signal,
153                     ..
154                 } => {
155                     *waiting_for_abort_signal = Some(cx.waker().clone());
156                 }
157                 JoinState::AbortRequested { .. } | JoinState::Complete => {
158                     return Poll::Ready(None);
159                 }
160             }
161         }
162 
163         future.poll(cx).map(Some)
164     }
165 }
166 
167 impl<F> Drop for JoinHandleFuture<F> {
drop(&mut self)168     fn drop(&mut self) {
169         // SAFETY: this is the exclusive owner of this future and it's safe to
170         // drop here during the owning destructor.
171         //
172         // Note that this explicitly happens before notifying the abort handle
173         // that the task completed so that when the notification goes through
174         // it's guaranteed that the future has been destroyed.
175         unsafe {
176             ManuallyDrop::drop(&mut self.future);
177         }
178 
179         // After the future dropped see if there was a task awaiting its
180         // destruction. Simultaneously flag this state as complete.
181         let prev = mem::replace(&mut *self.state.lock().unwrap(), JoinState::Complete);
182         let task = match prev {
183             JoinState::Running {
184                 waiting_for_abort_to_complete,
185                 ..
186             }
187             | JoinState::AbortRequested {
188                 waiting_for_abort_to_complete,
189             } => waiting_for_abort_to_complete,
190             JoinState::Complete => None,
191         };
192         if let Some(task) = task {
193             task.wake();
194         }
195     }
196 }
197 
#[cfg(test)]
mod tests {
    use super::JoinHandle;
    use std::pin::{Pin, pin};
    use std::task::{Context, Poll, Waker};
    use tokio::sync::oneshot;

    /// Polls `future` once with a no-op waker and returns the raw result, so
    /// tests can assert on the produced value and not just on readiness.
    fn poll_once<F>(future: Pin<&mut F>) -> Poll<F::Output>
    where
        F: Future,
    {
        future.poll(&mut Context::from_waker(Waker::noop()))
    }

    /// Polls `future` once with a no-op waker and reports whether it is ready.
    fn is_ready<F>(future: Pin<&mut F>) -> bool
    where
        F: Future,
    {
        poll_once(future).is_ready()
    }

    #[tokio::test]
    async fn abort_in_progress() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        {
            let mut future = pin!(future);
            // Nothing was sent and no abort was requested: neither side is ready.
            assert!(!is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            handle.abort();
            // The wrapper resolves to `None` once aborted...
            assert!(matches!(poll_once(future.as_mut()), Poll::Ready(None)));
            // ...but the handle keeps waiting until the wrapped future is dropped.
            assert!(!is_ready(handle.as_mut()));
            assert!(!tx.is_closed());
        }
        // Leaving the scope drops the wrapper and the receiver inside it.
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    #[tokio::test]
    async fn abort_complete() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        tx.send(()).unwrap();
        assert!(!is_ready(handle.as_mut()));
        {
            let mut future = pin!(future);
            // Without an abort, the inner future's output is passed through.
            assert!(matches!(
                poll_once(future.as_mut()),
                Poll::Ready(Some(Ok(())))
            ));
            assert!(!is_ready(handle.as_mut()));
        }
        assert!(is_ready(handle.as_mut()));
        // Aborting after destruction is a no-op.
        handle.abort();
        assert!(is_ready(handle.as_mut()));
    }

    #[tokio::test]
    async fn abort_dropped() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        drop(future);
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    #[tokio::test]
    async fn await_completion() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        let task = tokio::task::spawn(future);
        handle.await;
        // No abort was requested, so the channel's value is passed through.
        assert!(matches!(task.await.unwrap(), Some(Ok(()))));
    }

    #[tokio::test]
    async fn await_abort() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        handle.abort();
        let task = tokio::task::spawn(future);
        handle.await;
        // The abort was requested before the first poll, so the wrapper
        // resolves to `None` even though a value was available.
        assert!(task.await.unwrap().is_none());
    }
}
283