1 #[cfg(test)]
2 mod operation_test;
3 
4 use std::fmt;
5 use std::future::Future;
6 use std::pin::Pin;
7 use std::sync::atomic::{AtomicUsize, Ordering};
8 use std::sync::Arc;
9 use tokio::sync::mpsc;
10 use waitgroup::WaitGroup;
11 
12 use crate::error::Result;
13 
/// A repeatable asynchronous task paired with a static, human-readable description.
///
/// Field `0` is the task: every call produces a fresh boxed future that resolves to a
/// `bool`. When the `Operations` executor sees `true`, it enqueues the same operation
/// again. Field `1` is the description; it is the only part shown by the `Debug` output.
pub struct Operation(
    pub Box<dyn (FnMut() -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>) + Send + Sync>,
    pub &'static str,
);

impl Operation {
    /// Builds an [`Operation`] from a future-producing closure and its description.
    ///
    /// The closure is boxed into a trait object. Operations built from different
    /// closure types therefore all share the single `Operation` type.
    pub(crate) fn new(
        op: impl FnMut() -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + Send + Sync + 'static,
        description: &'static str,
    ) -> Self {
        let task = Box::new(op);
        Operation(task, description)
    }
}

impl fmt::Debug for Operation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A boxed closure cannot be printed, so its slot gets a "_" placeholder.
        let mut tuple = f.debug_tuple("Operation");
        tuple.field(&"_");
        tuple.field(&self.1);
        tuple.finish()
    }
}
37 
38 /// Operations is a task executor.
39 #[derive(Default)]
40 pub(crate) struct Operations {
41     length: Arc<AtomicUsize>,
42     ops_tx: Option<Arc<mpsc::UnboundedSender<Operation>>>,
43     close_tx: Option<mpsc::Sender<()>>,
44 }
45 
46 impl Operations {
new() -> Self47     pub(crate) fn new() -> Self {
48         let length = Arc::new(AtomicUsize::new(0));
49         let (ops_tx, ops_rx) = mpsc::unbounded_channel();
50         let (close_tx, close_rx) = mpsc::channel(1);
51         let l = Arc::clone(&length);
52         let ops_tx = Arc::new(ops_tx);
53         let ops_tx2 = Arc::clone(&ops_tx);
54         tokio::spawn(async move {
55             Operations::start(l, ops_tx, ops_rx, close_rx).await;
56         });
57 
58         Operations {
59             length,
60             ops_tx: Some(ops_tx2),
61             close_tx: Some(close_tx),
62         }
63     }
64 
65     /// enqueue adds a new action to be executed. If there are no actions scheduled,
66     /// the execution will start immediately in a new goroutine.
enqueue(&self, op: Operation) -> Result<()>67     pub(crate) async fn enqueue(&self, op: Operation) -> Result<()> {
68         if let Some(ops_tx) = &self.ops_tx {
69             return Operations::enqueue_inner(op, ops_tx, &self.length);
70         }
71 
72         Ok(())
73     }
74 
enqueue_inner( op: Operation, ops_tx: &Arc<mpsc::UnboundedSender<Operation>>, length: &Arc<AtomicUsize>, ) -> Result<()>75     fn enqueue_inner(
76         op: Operation,
77         ops_tx: &Arc<mpsc::UnboundedSender<Operation>>,
78         length: &Arc<AtomicUsize>,
79     ) -> Result<()> {
80         length.fetch_add(1, Ordering::SeqCst);
81         ops_tx.send(op)?;
82 
83         Ok(())
84     }
85 
86     /// is_empty checks if there are tasks in the queue
is_empty(&self) -> bool87     pub(crate) async fn is_empty(&self) -> bool {
88         self.length.load(Ordering::SeqCst) == 0
89     }
90 
91     /// Done blocks until all currently enqueued operations are finished executing.
92     /// For more complex synchronization, use Enqueue directly.
done(&self)93     pub(crate) async fn done(&self) {
94         let wg = WaitGroup::new();
95         let mut w = Some(wg.worker());
96         let _ = self
97             .enqueue(Operation::new(
98                 move || {
99                     let _d = w.take();
100                     Box::pin(async { false })
101                 },
102                 "Operation::done",
103             ))
104             .await;
105         wg.wait().await;
106     }
107 
start( length: Arc<AtomicUsize>, ops_tx: Arc<mpsc::UnboundedSender<Operation>>, mut ops_rx: mpsc::UnboundedReceiver<Operation>, mut close_rx: mpsc::Receiver<()>, )108     pub(crate) async fn start(
109         length: Arc<AtomicUsize>,
110         ops_tx: Arc<mpsc::UnboundedSender<Operation>>,
111         mut ops_rx: mpsc::UnboundedReceiver<Operation>,
112         mut close_rx: mpsc::Receiver<()>,
113     ) {
114         loop {
115             tokio::select! {
116                 _ = close_rx.recv() => {
117                     break;
118                 }
119                 result = ops_rx.recv() => {
120                     if let Some(mut f) = result {
121                         length.fetch_sub(1, Ordering::SeqCst);
122                         if f.0().await {
123                             // Requeue this operation
124                             let _ = Operations::enqueue_inner(f, &ops_tx, &length);
125                         }
126                     }
127                 }
128             }
129         }
130     }
131 
close(&self) -> Result<()>132     pub(crate) async fn close(&self) -> Result<()> {
133         if let Some(close_tx) = &self.close_tx {
134             close_tx.send(()).await?;
135         }
136         Ok(())
137     }
138 }
139